Introduction¶

TP53, known as the Guardian of the Genome, is the most commonly mutated gene in human cancers, and understanding its regulation is critical. The paper "A machine learning and directed network optimization approach to uncover TP53 regulatory patterns" explored this with two strategies: machine learning to predict TP53 mutation status from transcriptomic data, and directed regulatory networks to analyze the impact of these mutations on the expression of TP53 target genes.

In this notebook we provide our own solution to the first of these two problems: predicting, from a cell's expression profile, the type of TP53 mutation it carries and whether that mutation compromises the functionality of the cell itself.

Exploratory Data Analysis¶

The data comes from The Cancer Genome Atlas (TCGA) program and contains expression profiles of different cancer cells in which TP53 is mutated. The data includes the type of mutation (missense or others) and whether the cell was compromised.

In [ ]:
import pandas as pd
import statsmodels.api as sm
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import sklearn

import json
import os
import sys
import warnings

import torch
import scanpy as sc
import networkx as nx
import tqdm
import gseapy as gp

import torchtext

torchtext.disable_torchtext_deprecation_warning()

sys.path.insert(0, "../")
import scgpt as scg
from scgpt.tasks import GeneEmbedding
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.model import TransformerModel
from scgpt.preprocess import Preprocessor
from scgpt.utils import set_seed

os.environ["KMP_WARNINGS"] = "off"
warnings.filterwarnings('ignore')
In [ ]:
csv_file = 'data/TCGA_labels.csv'
df = pd.read_csv(csv_file)
df.head()
Variant_Classification ABCB9..ENSG00023457 ABLIM1..ENSG0003983 ACTA2..ENSG00059 ACTB..ENSG00060 ADORA2B..ENSG000136 ADRB2..ENSG000154 AEBP2..ENSG000121536 AEN..ENSG00064782 AGAP1..ENSG000116987 ... ZCCHC2..ENSG00054877 ZDHHC14..ENSG00079683 ZFP36L1..ENSG000677 ZMAT3..ENSG00064393 ZMIZ1..ENSG00057178 ZMIZ2..ENSG00083637 ZMYND8..ENSG00023613 ZNF561..ENSG00093134 is_true mutation
0 A129Vfs*20_TCGA-66-2785_Frame_Shift_Ins_17:g.7... 376.831000 1358.86000 2471.580000 143602.00000 159.674000 63.136500 946.639000 626.477000 344.195000 ... 323.344000 75.356400 8558.040000 43.991900 1783.300000 5320.570000 1018.330000 821.181000 True Frame_Shift_Ins
1 A138_P142del_TCGA-25-2393_In_Frame_Del_17:g.75... 198.244448 5367.62179 2528.570328 77726.97678 19.656121 2.579692 2130.976296 732.991931 386.605718 ... 228.638412 322.247574 6446.509718 36.542642 3207.438557 3213.116903 1688.261865 1149.407697 True In_Frame_Del
2 A138Cfs*27_TCGA-55-6980_Frame_Shift_Del_17:g.7... 117.516000 1936.34000 14533.700000 185841.00000 95.490700 191.866000 766.578000 256.410000 239.611000 ... 230.672000 121.132000 12726.800000 74.270600 2496.910000 4005.300000 923.961000 391.689000 True Frame_Shift_Del
3 A138Cfs*27_TCGA-55-6980_Frame_Shift_Del_17:g.7... 60.747000 5667.60000 3560.420000 107645.00000 86.834700 1047.620000 698.413000 186.741000 262.372000 ... 638.609000 343.604000 8024.280000 78.431400 3746.030000 2692.810000 1168.070000 670.402000 True Frame_Shift_Del
4 A138Cfs*27_TCGA-D8-A13Y_Frame_Shift_Del_17:g.7... 327.477000 1096.61000 3430.480000 64166.60000 51.837300 9.491300 706.010000 1617.540000 821.366000 ... 806.811000 124.118000 1350.690000 237.649000 1885.860000 2283.400000 1967.630000 480.043000 True Frame_Shift_Del

5 rows × 554 columns

Basic Exploration¶

In [ ]:
import sklearn
from sklearn.decomposition import PCA
from copy import deepcopy
In [ ]:
df_full = deepcopy(df)
df = df.drop(columns=['is_true', 'mutation', 'Variant_Classification'])
In [ ]:
df.head()
ABCB9..ENSG00023457 ABLIM1..ENSG0003983 ACTA2..ENSG00059 ACTB..ENSG00060 ADORA2B..ENSG000136 ADRB2..ENSG000154 AEBP2..ENSG000121536 AEN..ENSG00064782 AGAP1..ENSG000116987 AHDC1..ENSG00027245 ... ZBTB38..ENSG000253461 ZBTB7C..ENSG000201501 ZCCHC2..ENSG00054877 ZDHHC14..ENSG00079683 ZFP36L1..ENSG000677 ZMAT3..ENSG00064393 ZMIZ1..ENSG00057178 ZMIZ2..ENSG00083637 ZMYND8..ENSG00023613 ZNF561..ENSG00093134
0 376.831000 1358.86000 2471.580000 143602.00000 159.674000 63.136500 946.639000 626.477000 344.195000 435.438000 ... 1820.37000 264.358000 323.344000 75.356400 8558.040000 43.991900 1783.300000 5320.570000 1018.330000 821.181000
1 198.244448 5367.62179 2528.570328 77726.97678 19.656121 2.579692 2130.976296 732.991931 386.605718 1185.830576 ... 527.55836 16.858995 228.638412 322.247574 6446.509718 36.542642 3207.438557 3213.116903 1688.261865 1149.407697
2 117.516000 1936.34000 14533.700000 185841.00000 95.490700 191.866000 766.578000 256.410000 239.611000 1976.130000 ... 1439.43000 164.456000 230.672000 121.132000 12726.800000 74.270600 2496.910000 4005.300000 923.961000 391.689000
3 60.747000 5667.60000 3560.420000 107645.00000 86.834700 1047.620000 698.413000 186.741000 262.372000 738.562000 ... 2136.32000 414.566000 638.609000 343.604000 8024.280000 78.431400 3746.030000 2692.810000 1168.070000 670.402000
4 327.477000 1096.61000 3430.480000 64166.60000 51.837300 9.491300 706.010000 1617.540000 821.366000 1041.860000 ... 1168.53000 183.986000 806.811000 124.118000 1350.690000 237.649000 1885.860000 2283.400000 1967.630000 480.043000

5 rows × 551 columns

In [ ]:
df.describe()
ABCB9..ENSG00023457 ABLIM1..ENSG0003983 ACTA2..ENSG00059 ACTB..ENSG00060 ADORA2B..ENSG000136 ADRB2..ENSG000154 AEBP2..ENSG000121536 AEN..ENSG00064782 AGAP1..ENSG000116987 AHDC1..ENSG00027245 ... ZBTB38..ENSG000253461 ZBTB7C..ENSG000201501 ZCCHC2..ENSG00054877 ZDHHC14..ENSG00079683 ZFP36L1..ENSG000677 ZMAT3..ENSG00064393 ZMIZ1..ENSG00057178 ZMIZ2..ENSG00083637 ZMYND8..ENSG00023613 ZNF561..ENSG00093134
count 4211.000000 4211.000000 4211.000000 4211.000000 4211.000000 4211.000000 4211.000000 4211.000000 4211.000000 4211.000000 ... 4211.000000 4211.000000 4211.000000 4211.000000 4211.000000 4211.000000 4211.000000 4211.000000 4211.000000 4211.000000
mean 185.589301 3954.889567 7200.974864 114439.524334 207.436478 147.286450 825.070216 592.848118 442.641521 1094.332390 ... 1777.897799 534.550961 412.665719 255.096842 7621.347871 126.568887 3142.460706 2919.691552 1307.240249 661.287621
std 148.231968 3147.431047 17236.982170 48604.774763 241.214042 262.220737 560.475285 350.756481 239.295632 749.287419 ... 1141.164258 740.663305 239.702838 231.295388 4516.262873 109.098272 2018.346330 1489.464482 798.890198 325.307174
min 8.051500 29.764800 30.048173 24218.900000 0.000000 -0.210767 20.852200 58.117025 8.992800 57.971000 ... 42.885000 -0.319802 33.801400 0.000000 278.652000 0.399949 96.363800 339.315000 64.507200 81.250000
25% 103.270500 1815.710000 1221.666028 79320.416245 39.071550 15.592600 519.444000 364.062373 276.480676 553.878500 ... 1040.665390 83.612450 259.671500 117.194500 4516.985000 56.959424 1839.472102 1958.070000 791.811500 459.930000
50% 151.606000 3167.505717 2886.380000 105861.297800 123.480903 48.509415 697.903000 523.519000 400.155965 906.412000 ... 1546.240000 263.096000 365.717000 189.090000 6659.250000 99.149100 2705.100000 2640.281167 1107.065208 592.370060
75% 224.627000 5211.564607 6516.566190 139285.500000 298.676000 154.258000 965.818000 734.935000 559.460000 1452.640926 ... 2216.760000 662.011282 504.985500 304.513576 9518.812363 163.171022 3862.868994 3476.665000 1606.385000 773.011315
max 3918.930000 34507.870810 252607.000000 478521.000000 3115.940000 4181.060000 11942.500000 6406.190000 2910.670000 8566.567356 ... 11829.000000 8530.350000 3955.320000 4229.350000 37363.300000 1614.630000 26002.200000 22667.957440 12483.000000 4091.275370

8 rows × 551 columns

Cumulative distribution function of a few genes:

In [ ]:
L = 2
fig, axs = plt.subplots(L, L)
for i in range(L):
    for j in range(L):
        gene = df.columns[i*L + j]
        gene_data = df[gene].sort_values()
        gene_data = gene_data.reset_index(drop=True)
        ecdf = sm.distributions.ECDF(gene_data)
        axs[i, j].step(ecdf.x, ecdf.y)
plt.show()
[Figure: empirical CDFs of the first four genes]

We notice that the gene reads span multiple orders of magnitude. There are some outliers, as can be seen from the quantiles and from the curves above.

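As a rough quantification of this spread, we can look at the ratio between each gene's largest and smallest value (a quick sketch; we clip the minimum at 1, since a few entries are zero or slightly negative):

In [ ]:
# ratio between the largest and smallest expression value of each gene (sketch)
ratio = df.max() / df.min().clip(lower=1)
print(ratio.describe())
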
We visualize the distribution of a few cells and genes to get a feeling for the data, using violin plots.

In [ ]:
L = 2
fig, axs = plt.subplots(L, L, figsize=(10, 10))
for i in range(L):
    for j in range(L):
        sns.violinplot(x=df.iloc[i * L + j, :], ax=axs[i, j])
        axs[i, j].set_title('Gene Expression profile of cell {}'.format(i * L + j))
[Figure: violin plots of the expression profiles of four cells]
In [ ]:
df_small = df.iloc[:10, :].T
plt.figure(figsize=(16, 6))
sns.violinplot(data=df_small, palette='Set3', cut=0)
plt.xticks(rotation=90)
plt.title("Distributions of a sample of gene expression profiles")
plt.show()
[Figure: violin plots of a sample of ten gene expression profiles]
In [ ]:
L = 2
fig, axs = plt.subplots(L, L, figsize=(10, 10))
for i in range(L):
    for j in range(L):
        sns.violinplot(x=df.iloc[:, i * L + j], ax=axs[i, j])  # column = gene
        axs[i, j].set_title('Expression of gene {} across cells'.format(i * L + j))
[Figure: violin plots of the expression of four genes across cells]
In [ ]:
df_small = df.iloc[:, :10]
plt.figure(figsize=(16, 6))
sns.violinplot(data=df_small, palette='Set3', cut=0)
plt.xticks(rotation=90)
plt.title("Distributions of the expression levels across cells for a sample of genes")
plt.show()
[Figure: violin plots of the expression levels across cells for ten genes]

We check the sparsity level of the gene expression profiles. In some single-cell sequencing datasets the examples are extremely sparse. This is important to take into account, both for computational efficiency and to correctly interpret statistics. In our case, however, the sparsity level turns out to be low:

In [ ]:
print("sparsity level:", (df == 0).sum().sum() / df.size)
sparsity level: 0.02183073369763143

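Had the sparsity been high, switching to a sparse matrix representation would save memory and speed up many computations; a minimal sketch of what that would look like (not needed here):

In [ ]:
from scipy import sparse

# store only the non-zero entries (worthwhile only for highly sparse data)
X_sparse = sparse.csr_matrix(df.to_numpy())
print(X_sparse.nnz, "non-zero entries out of", df.size)
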
As noticed above, the distributions cover a very long range, spanning multiple orders of magnitude. This suggests the data is more naturally handled on a log scale, so we log-transform it (adding 1 first, so that zero entries map to zero).

In [ ]:
df = df + 1
df = df.apply(np.log)
df.head()
ABCB9..ENSG00023457 ABLIM1..ENSG0003983 ACTA2..ENSG00059 ACTB..ENSG00060 ADORA2B..ENSG000136 ADRB2..ENSG000154 AEBP2..ENSG000121536 AEN..ENSG00064782 AGAP1..ENSG000116987 AHDC1..ENSG00027245 ... ZBTB38..ENSG000253461 ZBTB7C..ENSG000201501 ZCCHC2..ENSG00054877 ZDHHC14..ENSG00079683 ZFP36L1..ENSG000677 ZMAT3..ENSG00064393 ZMIZ1..ENSG00057178 ZMIZ2..ENSG00083637 ZMYND8..ENSG00023613 ZNF561..ENSG00093134
0 5.934447 7.215137 7.813017 11.874808 5.079377 4.161014 6.853974 6.441707 5.844109 6.078646 ... 7.507344 5.581080 5.781805 4.335412 9.054743 3.806482 7.486781 8.579524 6.926901 6.711961
1 5.294532 8.588327 7.835805 11.260971 3.028012 1.275277 7.664805 6.598498 5.959989 7.079042 ... 6.270153 2.882507 5.436506 5.778419 8.771449 3.625477 8.073540 8.075308 7.432047 7.047872
2 4.775048 7.569071 9.584294 12.132652 4.569447 5.261996 6.643240 5.550670 5.483182 7.589402 ... 7.272697 5.108705 5.445323 4.805102 9.451544 4.321090 7.823210 8.295623 6.829752 5.973018
3 4.123045 8.642697 8.177915 11.586603 4.475457 6.955230 6.550241 5.235063 5.573567 6.606058 ... 7.667308 6.029641 6.460857 5.842396 8.990352 4.374894 8.228719 7.898712 7.063964 6.509368
4 5.794467 7.000890 8.140747 11.069254 3.967217 2.350546 6.561045 7.389280 6.712186 6.949722 ... 7.064357 5.220280 6.694328 4.829257 7.209111 5.474994 7.542669 7.733859 7.585093 6.175957

5 rows × 551 columns

In [ ]:
df.describe()
ABCB9..ENSG00023457 ABLIM1..ENSG0003983 ACTA2..ENSG00059 ACTB..ENSG00060 ADORA2B..ENSG000136 ADRB2..ENSG000154 AEBP2..ENSG000121536 AEN..ENSG00064782 AGAP1..ENSG000116987 AHDC1..ENSG00027245 ... ZBTB38..ENSG000253461 ZBTB7C..ENSG000201501 ZCCHC2..ENSG00054877 ZDHHC14..ENSG00079683 ZFP36L1..ENSG000677 ZMAT3..ENSG00064393 ZMIZ1..ENSG00057178 ZMIZ2..ENSG00083637 ZMYND8..ENSG00023613 ZNF561..ENSG00093134
count 4211.000000 4211.000000 4211.000000 4211.000000 4211.000000 4211.000000 4211.000000 4211.000000 4211.000000 4211.000000 ... 4211.000000 4211.000000 4211.000000 4211.000000 4211.000000 4211.000000 4211.000000 4211.000000 4211.000000 4211.000000
mean 5.027813 7.974901 7.969045 11.565688 4.650616 3.932039 6.578123 6.240796 5.955885 6.785079 ... 7.302806 5.368857 5.888605 5.255410 8.778089 4.571372 7.876275 7.872595 7.024809 6.400346
std 0.626356 0.858968 1.283715 0.404949 1.313455 1.525478 0.506659 0.544390 0.548099 0.671384 ... 0.623959 1.578166 0.523837 0.762479 0.580441 0.760368 0.606623 0.459292 0.548695 0.428670
min 2.202930 3.426371 3.435540 10.094930 0.000000 -0.236694 3.084302 4.079519 2.301865 4.077046 ... 3.781573 -0.385372 3.549658 0.000000 5.633546 0.336436 4.578454 5.829872 4.182160 4.409763
25% 4.646988 7.504782 7.108789 11.281263 3.690667 2.808957 6.254682 5.900068 5.625751 6.318749 ... 6.948576 4.438081 5.563261 4.772332 8.415821 4.059743 7.517777 7.580225 6.675585 6.133246
50% 5.027859 8.061015 7.968105 11.569894 4.824152 3.902163 6.549512 6.262482 5.994350 6.810597 ... 7.344228 5.576313 5.904590 5.247498 8.803912 4.606660 7.903264 7.879019 7.010371 6.385818
75% 5.418883 8.558827 8.782255 11.844288 5.702702 5.045088 6.874010 6.601142 6.328758 7.281827 ... 7.704253 6.496792 6.226508 5.721994 9.161130 5.100909 8.259424 8.154116 7.382364 6.651586
max 8.273829 10.448972 12.439594 13.078457 8.044607 8.338559 9.387942 8.765176 7.976482 9.055739 ... 9.378394 9.051503 8.283070 8.350040 10.528471 7.387480 10.165975 10.028752 9.432203 8.316856

8 rows × 551 columns

Now let's visualize the new data after applying the logarithm:

In [ ]:
df_small = df.iloc[:10, :].T
plt.figure(figsize=(16, 6))
sns.violinplot(data=df_small, palette='Set3', cut=0)
plt.xticks(rotation=90)
plt.title("Distributions of a sample of gene expression profiles")
plt.show()
[Figure: violin plots of ten expression profiles after the log transform]
In [ ]:
df_small = df.iloc[:, :10]
plt.figure(figsize=(16, 6))
sns.violinplot(data=df_small, palette='Set3', cut=0)
plt.xticks(rotation=90)
plt.title("Distributions of the expression levels across cells for a sample of genes")
plt.show()
[Figure: violin plots of ten genes across cells after the log transform]

To prepare for unsupervised and supervised machine learning algorithms, we also standardize the gene expression values, gene by gene. As the plots above show, different genes have different means and standard deviations, so scale-sensitive algorithms (e.g., k-means or logistic regression) would otherwise focus on the genes exhibiting the largest variation.

In [ ]:
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
array = scaler.fit_transform(df)
df = pd.DataFrame(array, columns=df.columns)
df.head()
ABCB9..ENSG00023457 ABLIM1..ENSG0003983 ACTA2..ENSG00059 ACTB..ENSG00060 ADORA2B..ENSG000136 ADRB2..ENSG000154 AEBP2..ENSG000121536 AEN..ENSG00064782 AGAP1..ENSG000116987 AHDC1..ENSG00027245 ... ZBTB38..ENSG000253461 ZBTB7C..ENSG000201501 ZCCHC2..ENSG00054877 ZDHHC14..ENSG00079683 ZFP36L1..ENSG000677 ZMAT3..ENSG00064393 ZMIZ1..ENSG00057178 ZMIZ2..ENSG00083637 ZMYND8..ENSG00023613 ZNF561..ENSG00093134
0 1.447646 -0.884613 -0.121559 0.763446 0.326477 0.150118 0.544515 0.369101 -0.203957 -1.052329 ... 0.327846 0.134490 -0.203906 -1.206731 0.476685 -1.006067 -0.642145 1.539354 -0.178460 0.727020
1 0.425878 0.714228 -0.103805 -0.752572 -1.235518 -1.741801 2.145054 0.657148 0.007488 0.437898 ... -1.655197 -1.575655 -0.863156 0.686013 -0.011440 -1.244144 0.325224 0.441411 0.742282 1.510726
2 -0.403597 -0.472518 1.258411 1.400254 -0.061806 0.871933 0.128538 -1.267854 -0.862544 1.198149 ... -0.048260 -0.164864 -0.846323 -0.590653 1.160385 -0.329199 -0.087487 0.921154 -0.355537 -0.996989
3 -1.444666 0.777533 0.162726 0.051656 -0.133374 1.982035 -0.055037 -1.847666 -0.697616 -0.266677 ... 0.584245 0.418754 1.092553 0.769930 0.365736 -0.258430 0.581062 0.056869 0.071368 0.254357
4 1.224136 -1.134066 0.133770 -1.226063 -0.520368 -1.036843 -0.033712 2.109920 1.380025 0.245258 ... -0.382200 -0.094157 1.538300 -0.558970 -2.703400 1.188542 -0.550005 -0.302103 1.021242 -0.523517

5 rows × 551 columns

In [ ]:
df.describe()
ABCB9..ENSG00023457 ABLIM1..ENSG0003983 ACTA2..ENSG00059 ACTB..ENSG00060 ADORA2B..ENSG000136 ADRB2..ENSG000154 AEBP2..ENSG000121536 AEN..ENSG00064782 AGAP1..ENSG000116987 AHDC1..ENSG00027245 ... ZBTB38..ENSG000253461 ZBTB7C..ENSG000201501 ZCCHC2..ENSG00054877 ZDHHC14..ENSG00079683 ZFP36L1..ENSG000677 ZMAT3..ENSG00064393 ZMIZ1..ENSG00057178 ZMIZ2..ENSG00083637 ZMYND8..ENSG00023613 ZNF561..ENSG00093134
count 4.211000e+03 4.211000e+03 4.211000e+03 4.211000e+03 4.211000e+03 4.211000e+03 4.211000e+03 4.211000e+03 4.211000e+03 4.211000e+03 ... 4.211000e+03 4.211000e+03 4.211000e+03 4.211000e+03 4.211000e+03 4.211000e+03 4.211000e+03 4.211000e+03 4.211000e+03 4.211000e+03
mean 1.554049e-15 -4.758325e-16 -3.678421e-16 -2.682885e-16 -5.239219e-16 -3.610927e-16 1.906705e-16 -5.264529e-16 2.902241e-16 1.149085e-15 ... 4.302740e-17 1.147397e-16 4.049638e-16 -9.449155e-17 3.610927e-16 -9.837246e-16 6.665029e-16 1.737126e-15 -1.404718e-16 1.181144e-16
std 1.000119e+00 1.000119e+00 1.000119e+00 1.000119e+00 1.000119e+00 1.000119e+00 1.000119e+00 1.000119e+00 1.000119e+00 1.000119e+00 ... 1.000119e+00 1.000119e+00 1.000119e+00 1.000119e+00 1.000119e+00 1.000119e+00 1.000119e+00 1.000119e+00 1.000119e+00 1.000119e+00
min -4.510562e+00 -5.295972e+00 -3.531972e+00 -3.632389e+00 -3.541172e+00 -2.733064e+00 -6.896626e+00 -3.970557e+00 -6.667505e+00 -4.033989e+00 ... -5.644040e+00 -3.646584e+00 -4.465560e+00 -6.893345e+00 -5.418151e+00 -5.570250e+00 -5.437008e+00 -4.448082e+00 -5.181365e+00 -4.644180e+00
25% -6.080725e-01 -5.473712e-01 -6.702104e-01 -7.024539e-01 -7.309452e-01 -7.363043e-01 -6.384559e-01 -6.259623e-01 -6.023963e-01 -6.946628e-01 ... -5.677806e-01 -5.898535e-01 -6.211531e-01 -6.336376e-01 -6.241986e-01 -6.729505e-01 -5.910434e-01 -6.366437e-01 -6.365389e-01 -6.231641e-01
50% 7.394250e-05 1.002653e-01 -7.327976e-04 1.038946e-02 1.321378e-01 -1.958719e-02 -5.647691e-02 3.984025e-02 7.018777e-02 3.801168e-02 ... 6.639363e-02 1.314692e-01 3.051905e-02 -1.037819e-02 4.449485e-02 4.641426e-02 4.449540e-02 1.398821e-02 -2.631787e-02 -3.389436e-02
75% 6.244316e-01 6.798808e-01 6.335571e-01 6.880706e-01 8.011018e-01 7.297264e-01 5.840663e-01 6.620046e-01 6.803826e-01 7.399737e-01 ... 6.434628e-01 7.147972e-01 6.451297e-01 6.120030e-01 6.599933e-01 6.965039e-01 6.316853e-01 6.130188e-01 6.517230e-01 5.861629e-01
max 5.182997e+00 2.880625e+00 3.482923e+00 3.736147e+00 2.584326e+00 2.888960e+00 5.546441e+00 4.637628e+00 3.686992e+00 3.382461e+00 ... 3.326874e+00 2.333775e+00 4.571553e+00 4.059123e+00 3.015966e+00 3.704052e+00 3.774953e+00 4.695084e+00 4.388014e+00 4.471364e+00

8 rows × 551 columns

Data visualization after normalization:

In [ ]:
df_small = df.iloc[:10, :].T
plt.figure(figsize=(16, 6))
sns.violinplot(data=df_small, palette='Set3', cut=0)
plt.xticks(rotation=90)
plt.title("Distributions of a sample of gene expression profiles")
plt.show()
[Figure: violin plots of ten expression profiles after standardization]
In [ ]:
df_small = df.iloc[:, :10]
plt.figure(figsize=(16, 6))
sns.violinplot(data=df_small, palette='Set3', cut=0)
plt.xticks(rotation=90)
plt.title("Distributions of the expression levels across cells for a sample of genes")
plt.show()
[Figure: violin plots of ten genes across cells after standardization]

Dimensionality Reduction¶

In order to visualize the data, and to check whether any interesting patterns are visible, we use two dimensionality reduction techniques.

First, we resort to a simple, interpretable linear technique, PCA, to find a small set of orthogonal directions that explain most of the variance of the data. We can also inspect the principal components, which represent 'metagenes' along which our dataset exhibits large variance, and which may turn out to be useful for classification as well.

Then, we employ a more powerful nonlinear dimensionality reduction technique, tSNE, which was designed specifically for data visualization, to get a more accurate depiction of the local structure of our data.

PCA¶

In [ ]:
from sklearn.decomposition import PCA

pca = PCA(n_components=2)
pca.fit(df)
array_pca = pca.transform(df)
df_pca = pd.DataFrame(array_pca, columns=['PC1', 'PC2'])
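
As mentioned above, the principal components can be read as 'metagenes'. A quick exploratory sketch, using the pca object just fitted, to see which genes load most heavily on the first component:

In [ ]:
# genes with the largest absolute loading on PC1 (exploratory sketch)
loadings = pd.Series(pca.components_[0], index=df.columns)
print(loadings.abs().sort_values(ascending=False).head(10))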
In [ ]:
plt.figure(figsize=(10, 10))
sns.scatterplot(x='PC1', y='PC2', data=df_pca, hue=df_full['is_true'], alpha=0.5)
plt.title("PCA of the gene expression profiles")
plt.show()
[Figure: PCA scatter plot of the expression profiles, colored by is_true]
In [ ]:
array_reconstructed = pca.inverse_transform(array_pca)
mse = ((array - array_reconstructed) ** 2).mean()
print("MSE of the reconstruction:", mse)
MSE of the reconstruction: 0.75991853023862
In [ ]:
explained_variance = pca.explained_variance_ratio_.sum()
print("Explained variance:", explained_variance)
Explained variance: 0.24008146976137446

With only two directions, just 24% of the variance is explained, and the reconstruction error is high: 0.76, for data whose features were standardized to mean 0 and standard deviation 1. (For standardized data the per-entry reconstruction MSE equals one minus the explained variance ratio, hence 0.76 = 1 − 0.24.) We now increase the number of principal components and monitor how the explained variance grows.

In [ ]:
pca = PCA(n_components=551)
pca.fit(df)
array_pca = pca.transform(df)
var = pca.explained_variance_ratio_[0:20]
labels = ["PC"+str(i+1) for i in range(20)]
plt.figure(figsize=(16,4))
plt.bar(labels, var)
plt.xlabel('Principal Components')
plt.ylabel('Proportion of Variance Explained')
plt.show()
[Figure: proportion of variance explained by the first 20 principal components]
In [ ]:
cum_var = np.cumsum(pca.explained_variance_ratio_)
plt.plot(cum_var)
plt.hlines(0.72, 0, 551, colors='r', linestyles='dashed', label='0.72 explained variance')
plt.xlabel('Number of Principal Components')
plt.ylabel('Cumulative Proportion of Variance Explained')
plt.legend()
plt.show()
[Figure: cumulative proportion of variance explained vs number of components]
In [ ]:
np.argmax(cum_var > 0.72)  # first index at which the cumulative explained variance exceeds 0.72
49

As is often the case, most of the variance is explained by relatively few components. There seems to be an elbow at around 50 components (index 49 above), which together explain 72% of the variance. These plots will be useful in the classification part, where we use PCA to extract features from the data.

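As a preview of how these features could be wired into a model (a hypothetical sketch, not the exact setup used later), the elbow value can be plugged into a pipeline that feeds PCA components to a classifier:

In [ ]:
from sklearn.pipeline import make_pipeline
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression

# hypothetical sketch: ~50 PCA 'metagenes' as input features for a classifier
clf = make_pipeline(PCA(n_components=50), LogisticRegression(max_iter=10000))
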
tSNE¶

As the results of tSNE are sensitive to the choice of hyperparameters, especially the perplexity, we try a few configurations on a grid and pick the hyperparameters that yield the most stable results.

In [ ]:
from sklearn.manifold import TSNE

L = 3
fig, axs = plt.subplots(L, L, figsize=(15, 15))
perplexities = [5, 10, 15, 20, 25, 30, 35, 40, 45]
for i in range(L):
    for j in range(L):
        perplexity = perplexities[i * L + j]
        tsne = TSNE(n_components=2, perplexity=perplexity, max_iter=1000)
        array_tsne = tsne.fit_transform(df)
        df_tsne = pd.DataFrame(array_tsne, columns=['tSNE1', 'tSNE2'])
        sns.scatterplot(x='tSNE1', y='tSNE2', data=df_tsne, alpha=0.5, ax=axs[i, j])
        axs[i, j].set_title('Perplexity = {}'.format(perplexity))
[Figure: 3x3 grid of tSNE embeddings for perplexities 5 to 45]

Results look crisper and more stable for perplexity values in the range 15-35; we go with 25. Notice that some clusters seem to emerge. We will investigate this later with clustering techniques, to see whether it is an artifact of the dimensionality reduction or whether clusters also emerge naturally in the higher-dimensional space. Now, let's see whether any structure emerges by coloring the points according to their labels: first functional vs dysfunctional, then mutation type.

In [ ]:
tsne = TSNE(n_components=2, perplexity=25, max_iter=1000)
array_tsne = tsne.fit_transform(df)
df_tsne = pd.DataFrame(array_tsne, columns=['tSNE1', 'tSNE2'])

plt.figure(figsize=(10, 10))
sns.scatterplot(x='tSNE1', y='tSNE2', data=df_tsne, hue=df_full['is_true'], alpha=0.5)
plt.title("tSNE of the gene expression profiles")
plt.show()
[Figure: tSNE embedding colored by is_true]
In [ ]:
plt.figure(figsize=(10, 10))
sns.scatterplot(x='tSNE1', y='tSNE2', data=df_tsne, hue=df_full['mutation'], alpha=0.5, palette='Set1')
plt.title("tSNE of the gene expression profiles")
plt.show()
[Figure: tSNE embedding colored by mutation type]

Clustering¶

In [ ]:
gene_columns = [col for col in df.columns if col not in ['is_true', 'mutation', 'Variant_Classification']]
X = df[gene_columns].to_numpy()
In [ ]:
X.shape
(4211, 551)
In [ ]:
# for visualization later

from sklearn.manifold import TSNE
from sklearn.cluster import KMeans

tsne = TSNE(n_components=2, perplexity=25, max_iter=1000)
array_tsne = tsne.fit_transform(df[gene_columns])
df_tsne = pd.DataFrame(array_tsne, columns=['tSNE1', 'tSNE2'])

We want to perform clustering on our data. Gene expression profiles live in a 551-dimensional space. Since clustering is entirely based on a notion of distance in the data space, it is important that the metric used be meaningful. However, because of the well-known curse of dimensionality, distances tend to lose meaning as the dimension of the ambient space increases: counter-intuitive phenomena appear, such as the concentration of volume near boundaries and the quasi-orthogonality of random vectors (see the sketch below). For this reason, we carry out clustering on dimensionally reduced data, using PCA. For the number of components we use the elbow highlighted above, since it strikes a good trade-off between retained information and dimensionality.

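A small self-contained sketch of one such phenomenon, the quasi-orthogonality of random vectors: as the dimension grows, the cosine of the angle between random pairs concentrates around 0.

In [ ]:
rng = np.random.default_rng(0)
for d in [2, 50, 500]:
    a = rng.normal(size=(250, d))
    b = rng.normal(size=(250, d))
    # mean |cos(angle)| between random pairs of vectors shrinks with the dimension
    cos = (a * b).sum(axis=1) / (np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1))
    print(d, np.abs(cos).mean())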
In [ ]:
pca = PCA(n_components=50)
pca.fit(X)
X_pca = pca.transform(X)

To choose the number of clusters, which must be provided as input to k-means, we use two indicators: the Silhouette score and the inertia.

In [ ]:
from sklearn.cluster import KMeans
inertia = []
for k in range(1, 51):
    kmeans = KMeans(n_clusters=k)
    kmeans.fit(X_pca)
    inertia.append(kmeans.inertia_)
plt.plot(range(1, 51), inertia)
[<matplotlib.lines.Line2D at 0x2a6348b50>]
[Figure: k-means inertia vs number of clusters]
In [ ]:
# silhouette score
from sklearn.metrics import silhouette_score

silhouette = []
for k in range(2, 51):
    kmeans = KMeans(n_clusters=k)
    kmeans.fit(X_pca)
    silhouette.append(silhouette_score(X, kmeans.labels_))
plt.plot(range(2, 51), silhouette)
[<matplotlib.lines.Line2D at 0x2a637f8d0>]
[Figure: Silhouette score vs number of clusters]
In [ ]:
2 + np.argmax(silhouette)  # the tested k values start at 2
3

In the run above, the Silhouette score is maximized at k = 3; the exact maximizer can vary between runs, since KMeans is not seeded here. The inertia gives little indication, as no clear elbow is visible. We go with k = 6.

In [ ]:
k = 6
kmeans = KMeans(n_clusters=k)
kmeans.fit(X_pca)
df_tsne['cluster'] = kmeans.labels_

The Silhouette score is low, so the quality of the clusters is not optimal. Since the value is near 0, we expect many poorly clustered points.

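A quick sanity check of this claim (a sketch, computed on the PCA-reduced data the clustering was fitted on):

In [ ]:
from sklearn.metrics import silhouette_score

print("silhouette of the chosen clustering:", silhouette_score(X_pca, kmeans.labels_))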
In [ ]:
plt.figure(figsize=(10, 10))
sns.scatterplot(x='tSNE1', y='tSNE2', data=df_tsne, hue=df_tsne['cluster'], alpha=0.5, palette='Set1')
plt.title("tSNE of the gene expression profiles")
plt.show()
[Figure: tSNE embedding colored by k-means cluster]

Correlation for feature selection¶

To find the most promising genes to consider for classification, we calculate the correlation of each individual gene with the labels. We use the point-biserial correlation coefficient, since we are dealing with one continuous and one binary variable; it is a special case of the Pearson correlation coefficient, as the sketch below verifies.

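A minimal sketch on synthetic data verifying this equivalence (the variable names are illustrative):

In [ ]:
from scipy import stats

rng = np.random.default_rng(0)
binary = rng.integers(0, 2, size=100).astype(float)
continuous = rng.normal(size=100) + binary
r_pb, _ = stats.pointbiserialr(binary, continuous)
r_p, _ = stats.pearsonr(binary, continuous)
print(r_pb, r_p)  # identical up to floating point error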
In [ ]:
csv_file = 'data/TCGA_labels.csv'
df = pd.read_csv(csv_file)
df.head()
Variant_Classification ABCB9..ENSG00023457 ABLIM1..ENSG0003983 ACTA2..ENSG00059 ACTB..ENSG00060 ADORA2B..ENSG000136 ADRB2..ENSG000154 AEBP2..ENSG000121536 AEN..ENSG00064782 AGAP1..ENSG000116987 ... ZCCHC2..ENSG00054877 ZDHHC14..ENSG00079683 ZFP36L1..ENSG000677 ZMAT3..ENSG00064393 ZMIZ1..ENSG00057178 ZMIZ2..ENSG00083637 ZMYND8..ENSG00023613 ZNF561..ENSG00093134 is_true mutation
0 A129Vfs*20_TCGA-66-2785_Frame_Shift_Ins_17:g.7... 376.831000 1358.86000 2471.580000 143602.00000 159.674000 63.136500 946.639000 626.477000 344.195000 ... 323.344000 75.356400 8558.040000 43.991900 1783.300000 5320.570000 1018.330000 821.181000 True Frame_Shift_Ins
1 A138_P142del_TCGA-25-2393_In_Frame_Del_17:g.75... 198.244448 5367.62179 2528.570328 77726.97678 19.656121 2.579692 2130.976296 732.991931 386.605718 ... 228.638412 322.247574 6446.509718 36.542642 3207.438557 3213.116903 1688.261865 1149.407697 True In_Frame_Del
2 A138Cfs*27_TCGA-55-6980_Frame_Shift_Del_17:g.7... 117.516000 1936.34000 14533.700000 185841.00000 95.490700 191.866000 766.578000 256.410000 239.611000 ... 230.672000 121.132000 12726.800000 74.270600 2496.910000 4005.300000 923.961000 391.689000 True Frame_Shift_Del
3 A138Cfs*27_TCGA-55-6980_Frame_Shift_Del_17:g.7... 60.747000 5667.60000 3560.420000 107645.00000 86.834700 1047.620000 698.413000 186.741000 262.372000 ... 638.609000 343.604000 8024.280000 78.431400 3746.030000 2692.810000 1168.070000 670.402000 True Frame_Shift_Del
4 A138Cfs*27_TCGA-D8-A13Y_Frame_Shift_Del_17:g.7... 327.477000 1096.61000 3430.480000 64166.60000 51.837300 9.491300 706.010000 1617.540000 821.366000 ... 806.811000 124.118000 1350.690000 237.649000 1885.860000 2283.400000 1967.630000 480.043000 True Frame_Shift_Del

5 rows × 554 columns

In [ ]:
def log_and_normalize(df: pd.DataFrame) -> pd.DataFrame:
    # all columns but 'is_true', 'mutation', and 'Variant_Classification'
    features = df.drop(columns=["is_true", "mutation", "Variant_Classification"])
    # log-transform and normalize the features
    features = features.apply(lambda x: np.log(1 + x))
    features = (features - features.mean()) / features.std()
    # add back the non-numeric columns
    features = pd.concat(
        [features, df[["mutation", "Variant_Classification", "is_true"]]], axis=1
    )
    return features
In [ ]:
df = log_and_normalize(df)
In [ ]:
df.head()
ABCB9..ENSG00023457 ABLIM1..ENSG0003983 ACTA2..ENSG00059 ACTB..ENSG00060 ADORA2B..ENSG000136 ADRB2..ENSG000154 AEBP2..ENSG000121536 AEN..ENSG00064782 AGAP1..ENSG000116987 AHDC1..ENSG00027245 ... ZDHHC14..ENSG00079683 ZFP36L1..ENSG000677 ZMAT3..ENSG00064393 ZMIZ1..ENSG00057178 ZMIZ2..ENSG00083637 ZMYND8..ENSG00023613 ZNF561..ENSG00093134 mutation Variant_Classification is_true
0 1.447474 -0.884508 -0.121544 0.763355 0.326438 0.150100 0.544450 0.369058 -0.203933 -1.052204 ... -1.206587 0.476628 -1.005947 -0.642069 1.539171 -0.178439 0.726934 Frame_Shift_Ins A129Vfs*20_TCGA-66-2785_Frame_Shift_Ins_17:g.7... True
1 0.425827 0.714143 -0.103793 -0.752483 -1.235371 -1.741594 2.144800 0.657070 0.007487 0.437846 ... 0.685931 -0.011439 -1.243997 0.325185 0.441359 0.742193 1.510547 In_Frame_Del A138_P142del_TCGA-25-2393_In_Frame_Del_17:g.75... True
2 -0.403549 -0.472462 1.258262 1.400088 -0.061798 0.871830 0.128522 -1.267703 -0.862442 1.198007 ... -0.590583 1.160248 -0.329160 -0.087477 0.921045 -0.355494 -0.996870 Frame_Shift_Del A138Cfs*27_TCGA-55-6980_Frame_Shift_Del_17:g.7... True
3 -1.444494 0.777441 0.162707 0.051650 -0.133358 1.981800 -0.055030 -1.847446 -0.697533 -0.266645 ... 0.769839 0.365693 -0.258399 0.580993 0.056862 0.071359 0.254326 Frame_Shift_Del A138Cfs*27_TCGA-55-6980_Frame_Shift_Del_17:g.7... True
4 1.223990 -1.133931 0.133754 -1.225917 -0.520306 -1.036720 -0.033708 2.109670 1.379861 0.245229 ... -0.558904 -2.703079 1.188400 -0.549939 -0.302067 1.021121 -0.523455 Frame_Shift_Del A138Cfs*27_TCGA-D8-A13Y_Frame_Shift_Del_17:g.7... True

5 rows × 554 columns

In [ ]:
plt.figure(figsize=(10, 6))
sns.violinplot(x='is_true', y=df.columns[0], data=df)
plt.xlabel('Functional')
plt.title("Distribution of the first gene expression level for functional and dysfunctional cells")
plt.show()
[Figure: distribution of the first gene's expression for functional vs dysfunctional cells]

Correlations of genes with labels¶

Functional vs Dysfunctional¶

In [ ]:
from scipy import stats

correlations = {}
for gene in df.columns:
    if gene in ['is_true', 'mutation', 'Variant_Classification']:
        continue
    a = df['is_true'].to_numpy().astype(np.float64)
    b = df[gene].to_numpy().astype(np.float64)
    corr, pval = stats.pointbiserialr(a, b)
    correlations[gene] = corr
In [ ]:
corrs = list(correlations.values())
corrs = np.array(corrs)
np.abs(corrs).mean(), np.abs(corrs).std(), np.abs(corrs).max()
(0.027919123954705125, 0.01954913018439728, 0.08834416504312281)

The genes are not very correlated with the labels. Most of them would essentially provide noise for the classifier to see through, so it is probably better to remove them. Let's visualize the sorted correlations as a curve.

In [ ]:
corrs.sort()
plt.plot(corrs)
[<matplotlib.lines.Line2D at 0x2a7bb9ad0>]
[Figure: sorted point-biserial correlations with is_true]
In [ ]:
# save genes in order of absolute correlation for later use

import json

sorted_correlations = sorted(correlations.items(), key=lambda x: np.abs(x[1]), reverse=True)
good_genes = [corr[0] for corr in sorted_correlations]
with open('good_genes_tf.txt', 'w') as f:
    json.dump(good_genes, f)

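Later on, a classifier can be restricted to the top-k genes of this list; a usage sketch (k = 100 is an arbitrary choice for illustration):

In [ ]:
with open('good_genes_tf.txt') as f:
    top_genes = json.load(f)
X_top = df[top_genes[:100]]  # keep only the 100 most correlated genes
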
Mutation type¶

In [ ]:
from scipy import stats

correlations = {}
for gene in df.columns:
    if gene in ['is_true', 'mutation', 'Variant_Classification']:
        continue
    a = (df['mutation'] == 'Missense_Mutation').to_numpy().astype(np.float64)
    b = df[gene].to_numpy().astype(np.float64)
    corr, pval = stats.pointbiserialr(a, b)
    correlations[gene] = corr
In [ ]:
corrs = list(correlations.values())
corrs = np.array(corrs)
# note: a notebook echoes only the last expression of a cell,
# so only the max of the absolute correlations is displayed below
np.abs(corrs).mean()
np.abs(corrs).std()
np.abs(corrs).max()
0.09154977907792622
In [ ]:
corrs.sort()
plt.plot(corrs)
[<matplotlib.lines.Line2D at 0x2ab1a1ad0>]
[Figure: sorted point-biserial correlations with missense status]
In [ ]:
import json

sorted_correlations = sorted(correlations.items(), key=lambda x: np.abs(x[1]), reverse=True)
good_genes = [corr[0] for corr in sorted_correlations]
with open('good_genes_missense.txt', 'w') as f:
    json.dump(good_genes, f)

Classification¶

In [ ]:
from copy import deepcopy


df = deepcopy(df_full)
df = log_and_normalize(df)
df = df.drop(columns=["is_true", "mutation", "Variant_Classification"])
y = df_full['mutation']

We check the proportions of each class:

In [ ]:
y.value_counts() / len(y)
mutation
Missense_Mutation         0.642840
Nonsense_Mutation         0.129898
Frame_Shift_Del           0.093090
Splice_Site               0.066018
Frame_Shift_Ins           0.028734
In_Frame_Del              0.018048
Splice_Region             0.011874
Fusion_                   0.005937
In_Frame_Ins              0.003325
Translation_Start_Site    0.000237
Name: count, dtype: float64

Since non-missense mutations aren't well represented, we group them under a single label.

In [ ]:
y = y.apply(lambda x: 1 if x == 'Missense_Mutation' else 0)
y.value_counts() / len(y)
mutation
1    0.64284
0    0.35716
Name: count, dtype: float64

We split the dataset into three subsets: a train set to fit our models, a validation set to tune the hyperparameters, and a test set to assess the final results.

In [ ]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(df, y, test_size=0.2, random_state=0)
X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=0.5, random_state=0)

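Given the class imbalance seen above, a stratified split that preserves the class proportions in all three subsets would also be a reasonable choice; an alternative sketch:

In [ ]:
X_train, X_tmp, y_train, y_tmp = train_test_split(
    df, y, test_size=0.2, random_state=0, stratify=y
)
X_val, X_test, y_val, y_test = train_test_split(
    X_tmp, y_tmp, test_size=0.5, random_state=0, stratify=y_tmp
)
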
We now proceed to fit classifiers to our data. We focus here, and comment, on the task of predicting missense vs non-missense mutations on the TCGA dataset; the results for the other three data-task combinations are similar and can be found in the Appendix.

To solve this task, we refrain from using deep learning and instead resort to classical supervised learning methods that are simpler and less prone to overfitting. This choice is motivated by the scarcity of available labeled samples.

We use k-fold cross validation to obtain statistically robust estimates of the validation error of different hyperparameter combinations.

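As a miniature of what the grid searches below do internally, a k-fold sketch on the training split defined above:

In [ ]:
from sklearn.model_selection import cross_val_score
from sklearn.ensemble import RandomForestClassifier

# 5-fold cross-validated accuracy of a single hyperparameter configuration
scores = cross_val_score(RandomForestClassifier(max_depth=3), X_train, y_train, cv=5)
print(scores.mean(), scores.std())
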
Random forest¶

In [ ]:
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
from sklearn.ensemble import RandomForestClassifier


def hyperparameter_search(
    model,
    X_train,
    y_train,
    X_val,
    y_val,
    param_grid,
    search_type="grid",
    n_iter=10,
    scoring="accuracy",
    cv=5,
    verbose=2,
):
    """
    Perform hyperparameter search using grid search or random search.

    Parameters:
    - model: The machine learning model to tune.
    - X_train: Training feature set.
    - y_train: Training target variable.
    - X_val: Validation feature set.
    - y_val: Validation target variable.
    - param_grid: Dictionary of hyperparameters to search over.
    - search_type: 'grid' for GridSearchCV or 'random' for RandomizedSearchCV.
    - n_iter: Number of iterations for RandomizedSearchCV (ignored for GridSearchCV).
    - scoring: Scoring metric to use for evaluation.
    - cv: Number of cross-validation folds.
    - verbose: Verbosity level for the search.

    Returns:
    - best_model: The model with the best hyperparameters.
    - best_params: The best hyperparameters found during the search.
    - best_score: The best score achieved with the best hyperparameters.
    - all_results: DataFrame with hyperparameters and corresponding validation scores.
    """
    if search_type == "grid":
        search = GridSearchCV(
            estimator=model,
            param_grid=param_grid,
            scoring=scoring,
            cv=cv,
            return_train_score=True,
            verbose=verbose,
            n_jobs=4,
        )
    elif search_type == "random":
        search = RandomizedSearchCV(
            estimator=model,
            param_distributions=param_grid,
            scoring=scoring,
            cv=cv,
            n_iter=n_iter,
            random_state=0,
            return_train_score=True,
            verbose=verbose,
            n_jobs=4
        )
    else:
        raise ValueError("search_type must be either 'grid' or 'random'")

    search.fit(X_train, y_train)

    best_model = search.best_estimator_
    best_params = search.best_params_
    best_score = search.best_score_

    val_predictions = best_model.predict(X_val)
    val_score = accuracy_score(y_val, val_predictions)

    print(f"Validation Score with best hyperparameters: {val_score}")

    # Collect all results
    results = search.cv_results_
    all_results = pd.DataFrame(results)

    return best_model, best_params, best_score, all_results

model = RandomForestClassifier()
param_grid = {
    'n_estimators': [100, 200],
    'max_depth': [3, 5, 7, 10, 15, 20],
    'bootstrap': [True, False],
}
best_model, best_params, best_score, all_results = hyperparameter_search(model, X_train, y_train, X_val, y_val,
                                                                         param_grid, search_type='grid', cv=3,
                                                                         verbose=2)
Fitting 3 folds for each of 24 candidates, totalling 72 fits
Validation Score with best hyperparameters: 0.6484560570071259
In [ ]:
best_score, best_params
(0.6374702448506858, {'bootstrap': True, 'max_depth': 3, 'n_estimators': 200})
In [ ]:
all_results
mean_fit_time std_fit_time mean_score_time std_score_time param_bootstrap param_max_depth param_n_estimators params split0_test_score split1_test_score split2_test_score mean_test_score std_test_score rank_test_score split0_train_score split1_train_score split2_train_score mean_train_score std_train_score
0 2.112955 0.005620 0.028860 0.003517 True 3 100 {'bootstrap': True, 'max_depth': 3, 'n_estimat... 0.637578 0.637578 0.636364 0.637173 0.000572 5 0.637416 0.637416 0.639359 0.638064 0.000916
1 5.028177 0.136659 0.034389 0.012648 True 3 200 {'bootstrap': True, 'max_depth': 3, 'n_estimat... 0.637578 0.637578 0.637255 0.637470 0.000152 1 0.637416 0.637416 0.638468 0.637767 0.000496
2 3.512773 0.584872 0.017173 0.003729 True 5 100 {'bootstrap': True, 'max_depth': 5, 'n_estimat... 0.635797 0.638468 0.637255 0.637173 0.001092 4 0.662361 0.674833 0.670525 0.669240 0.005172
3 5.515238 0.007261 0.023251 0.001754 True 5 200 {'bootstrap': True, 'max_depth': 5, 'n_estimat... 0.637578 0.637578 0.635472 0.636876 0.000993 6 0.659243 0.668597 0.676313 0.668051 0.006980
4 3.552134 0.035117 0.015851 0.000998 True 7 100 {'bootstrap': True, 'max_depth': 7, 'n_estimat... 0.634907 0.633126 0.630125 0.632719 0.001973 11 0.767483 0.787082 0.776046 0.776871 0.008023
5 7.148444 0.140806 0.027100 0.001456 True 7 200 {'bootstrap': True, 'max_depth': 7, 'n_estimat... 0.634907 0.637578 0.630125 0.634203 0.003083 9 0.787973 0.781292 0.766251 0.778505 0.009084
6 4.993407 0.269705 0.025283 0.005352 True 10 100 {'bootstrap': True, 'max_depth': 10, 'n_estima... 0.634907 0.619768 0.622103 0.625593 0.006654 15 0.933630 0.935412 0.910953 0.926665 0.011134
7 10.375836 0.145891 0.064294 0.029963 True 10 200 {'bootstrap': True, 'max_depth': 10, 'n_estima... 0.635797 0.631345 0.623886 0.630343 0.004914 13 0.947884 0.946548 0.915850 0.936761 0.014796
8 6.138031 0.142559 0.025194 0.002995 True 15 100 {'bootstrap': True, 'max_depth': 15, 'n_estima... 0.625111 0.619768 0.599822 0.614901 0.010883 23 0.979510 0.978174 0.977293 0.978326 0.000911
9 12.712065 0.163985 0.046701 0.001244 True 15 200 {'bootstrap': True, 'max_depth': 15, 'n_estima... 0.634016 0.623330 0.614082 0.623809 0.008145 18 0.979510 0.978619 0.977738 0.978622 0.000723
10 6.724589 0.325609 0.028120 0.004744 True 20 100 {'bootstrap': True, 'max_depth': 20, 'n_estima... 0.629564 0.622440 0.607843 0.619949 0.009041 20 0.979510 0.979065 0.977738 0.978771 0.000753
11 13.505669 0.397157 0.049830 0.000677 True 20 200 {'bootstrap': True, 'max_depth': 20, 'n_estima... 0.628673 0.612645 0.620321 0.620546 0.006546 19 0.979510 0.979065 0.977738 0.978771 0.000753
12 2.749218 0.104987 0.013929 0.004260 False 3 100 {'bootstrap': False, 'max_depth': 3, 'n_estima... 0.637578 0.637578 0.637255 0.637470 0.000152 1 0.637416 0.637416 0.640695 0.638509 0.001545
13 5.937420 0.274229 0.021186 0.001302 False 3 200 {'bootstrap': False, 'max_depth': 3, 'n_estima... 0.637578 0.637578 0.637255 0.637470 0.000152 1 0.637416 0.637416 0.639804 0.638212 0.001126
14 4.377419 0.017964 0.014611 0.000963 False 5 100 {'bootstrap': False, 'max_depth': 5, 'n_estima... 0.637578 0.636687 0.630125 0.634797 0.003323 8 0.666370 0.681069 0.696349 0.681263 0.012240
15 8.535188 0.082705 0.023125 0.001524 False 5 200 {'bootstrap': False, 'max_depth': 5, 'n_estima... 0.636687 0.636687 0.636364 0.636580 0.000153 7 0.668151 0.675724 0.694568 0.679481 0.011107
16 5.826510 0.076224 0.018134 0.001061 False 7 100 {'bootstrap': False, 'max_depth': 7, 'n_estima... 0.640249 0.633126 0.626560 0.633312 0.005590 10 0.801336 0.813363 0.804541 0.806414 0.005085
17 12.183470 0.262563 0.041748 0.007930 False 7 200 {'bootstrap': False, 'max_depth': 7, 'n_estima... 0.636687 0.634907 0.625668 0.632421 0.004830 12 0.802673 0.819154 0.802315 0.808047 0.007855
18 8.546466 0.218283 0.032218 0.000473 False 10 100 {'bootstrap': False, 'max_depth': 10, 'n_estim... 0.635797 0.624221 0.615865 0.625294 0.008173 16 0.960802 0.963029 0.954586 0.959472 0.003573
19 15.394531 0.323630 0.040654 0.001421 False 10 200 {'bootstrap': False, 'max_depth': 10, 'n_estim... 0.635797 0.626892 0.619430 0.627373 0.006691 14 0.964365 0.966147 0.950134 0.960215 0.007166
20 9.346233 0.071505 0.027116 0.001459 False 15 100 {'bootstrap': False, 'max_depth': 15, 'n_estim... 0.634016 0.617988 0.604278 0.618761 0.012153 22 0.979510 0.979065 0.977738 0.978771 0.000753
21 18.498514 0.128986 0.046684 0.000652 False 15 200 {'bootstrap': False, 'max_depth': 15, 'n_estim... 0.634907 0.625111 0.612299 0.624106 0.009257 17 0.979510 0.979065 0.977738 0.978771 0.000753
22 10.118011 0.112264 0.026858 0.000628 False 20 100 {'bootstrap': False, 'max_depth': 20, 'n_estim... 0.606411 0.622440 0.610517 0.613123 0.006798 24 0.979510 0.979065 0.977738 0.978771 0.000753
23 16.992818 0.693453 0.039370 0.000675 False 20 200 {'bootstrap': False, 'max_depth': 20, 'n_estim... 0.628673 0.622440 0.606061 0.619058 0.009536 21 0.979510 0.979065 0.977738 0.978771 0.000753
In [ ]:
from sklearn.metrics import accuracy_score

y_pred = best_model.predict(X_test)
print(f"Test accuracy of best model: {accuracy_score(y_test, y_pred):.2f}")
Test accuracy of best model: 0.68

The accuracy of the best-performing random forest on the validation set is very bad: slightly above the fraction of the dataset that belongs to the most represented class. However, the performance on the training set is good: the models are overfitting. Later on, to regularize training, we will try to select and to extract features, to reduce noise in the training data and hopefully improve performance.

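For reference, the majority-class baseline on the test split (a quick sketch):

In [ ]:
baseline = y_test.value_counts(normalize=True).max()
print(f"majority-class baseline accuracy: {baseline:.2f}")
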
To visualize the impact of different parameters on the validation performance, we use bar plots (averaging over the other hyperparameters).

In [ ]:
from typing import Dict


def plot_hyperparameter_search_results(
    all_results: pd.DataFrame, param_grid: Dict, score_metric="mean_test_score"
):
    """
    Plot the effect of each hyperparameter on the validation score. Fixed hyperparameters are averaged over.
    :param all_results: dataframe with search result, as output by hyperparameter_search
    :param param_grid: dictionary of all hyperparameters and their values
    :param score_metric: metric to plot
    """
    fig, axes = plt.subplots(len(param_grid), 1, figsize=(10, 3 * len(param_grid)))
    for ax, (param, values) in zip(axes, param_grid.items()):
        means = all_results.groupby(f"param_{param}")[score_metric].mean()
        print(means)
        # draw one bar per value of the hyperparameter
        ax.bar(means.index, means.values)
        ax.set_title(f"Effect of {param} on {score_metric}")
        ax.set_xlabel(param)
        ax.set_ylabel(score_metric)
    plt.tight_layout()

plot_hyperparameter_search_results(all_results, param_grid)
param_n_estimators
100    0.627522
200    0.630021
Name: mean_test_score, dtype: float64
param_max_depth
3     0.637396
5     0.636356
7     0.633164
10    0.627151
15    0.620394
20    0.618169
Name: mean_test_score, dtype: float64
param_bootstrap
False    0.628314
True     0.629230
Name: mean_test_score, dtype: float64
[Figure: mean validation accuracy per hyperparameter value (random forest)]

We use a confusion matrix to understand the types of errors our best performing random forest is making.

In [ ]:
from sklearn.metrics import confusion_matrix
import seaborn as sns

cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
plt.show()
[Figure: confusion matrix of the best random forest on the test set]

Other Supervised methods¶

In [ ]:
from sklearn.linear_model import LogisticRegression

model = LogisticRegression(max_iter=10000)
param_grid = {
    "penalty": ['l2', None],
    "C": [0.001, 0.005, 0.01, 0.1],

}
best_model, best_params, best_score, all_results = hyperparameter_search(model, X_train, y_train, X_val, y_val,
                                                                         param_grid, search_type='grid', cv=3,
                                                                         verbose=2)

y_pred = best_model.predict(X_test)
print(f"Test accuracy of best model: {accuracy_score(y_test, y_pred):.2f}")

plot_hyperparameter_search_results(all_results, param_grid)
plt.show()

cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
plt.show()
Fitting 3 folds for each of 8 candidates, totalling 24 fits
/Users/fedezara/Desktop/Uni/ML_Lab/ML-lab/.env/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1208: UserWarning: Setting penalty=None will ignore the C and l1_ratio parameters
  warnings.warn(
/Users/fedezara/Desktop/Uni/ML_Lab/ML-lab/.env/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1208: UserWarning: Setting penalty=None will ignore the C and l1_ratio parameters
  warnings.warn(
/Users/fedezara/Desktop/Uni/ML_Lab/ML-lab/.env/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1208: UserWarning: Setting penalty=None will ignore the C and l1_ratio parameters
  warnings.warn(
[CV] END ................................C=0.001, penalty=l2; total time=   0.2s
[CV] END ................................C=0.001, penalty=l2; total time=   0.1s
[CV] END ................................C=0.001, penalty=l2; total time=   0.2s
[CV] END ................................C=0.005, penalty=l2; total time=   0.3s
[CV] END ................................C=0.005, penalty=l2; total time=   0.3s
[CV] END ..............................C=0.001, penalty=None; total time=   0.8s
[CV] END ..............................C=0.001, penalty=None; total time=   0.7s
[CV] END ..............................C=0.001, penalty=None; total time=   0.8s
[CV] END ................................C=0.005, penalty=l2; total time=   0.2s
/Users/fedezara/Desktop/Uni/ML_Lab/ML-lab/.env/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1208: UserWarning: Setting penalty=None will ignore the C and l1_ratio parameters
  warnings.warn(
/Users/fedezara/Desktop/Uni/ML_Lab/ML-lab/.env/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1208: UserWarning: Setting penalty=None will ignore the C and l1_ratio parameters
  warnings.warn(
/Users/fedezara/Desktop/Uni/ML_Lab/ML-lab/.env/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1208: UserWarning: Setting penalty=None will ignore the C and l1_ratio parameters
  warnings.warn(
[CV] END .................................C=0.01, penalty=l2; total time=   0.3s
[CV] END ..............................C=0.005, penalty=None; total time=   0.8s
[CV] END .................................C=0.01, penalty=l2; total time=   0.3s
[CV] END ..............................C=0.005, penalty=None; total time=   1.0s
/Users/fedezara/Desktop/Uni/ML_Lab/ML-lab/.env/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1208: UserWarning: Setting penalty=None will ignore the C and l1_ratio parameters
  warnings.warn(
/Users/fedezara/Desktop/Uni/ML_Lab/ML-lab/.env/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1208: UserWarning: Setting penalty=None will ignore the C and l1_ratio parameters
  warnings.warn(
[CV] END ..............................C=0.005, penalty=None; total time=   1.0s
[CV] END .................................C=0.01, penalty=l2; total time=   0.3s
/Users/fedezara/Desktop/Uni/ML_Lab/ML-lab/.env/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1208: UserWarning: Setting penalty=None will ignore the C and l1_ratio parameters
  warnings.warn(
[CV] END ...............................C=0.01, penalty=None; total time=   0.6s
[CV] END ..................................C=0.1, penalty=l2; total time=   0.5s
[CV] END ...............................C=0.01, penalty=None; total time=   0.7s
[CV] END ...............................C=0.01, penalty=None; total time=   0.6s
[CV] END ..................................C=0.1, penalty=l2; total time=   0.3s
[CV] END ..................................C=0.1, penalty=l2; total time=   0.5s
[CV] END ................................C=0.1, penalty=None; total time=   0.5s
[CV] END ................................C=0.1, penalty=None; total time=   0.6s
[CV] END ................................C=0.1, penalty=None; total time=   0.5s
Validation Score with best hyperparameters: 0.6555819477434679
Test accuracy of best model: 0.68
param_penalty
l2    0.60726
Name: mean_test_score, dtype: float64
param_C
0.001    0.603770
0.005    0.594864
0.010    0.590857
0.100    0.581205
Name: mean_test_score, dtype: float64
[figure: mean CV accuracy across the (C, penalty) grid]
[figure: confusion matrix of the best logistic regression model on the test set]
In [ ]:
from sklearn.neighbors import KNeighborsClassifier

model = KNeighborsClassifier()
param_grid = {
    'n_neighbors': [1, 3, 5, 10, 15],
    'weights': ['uniform', 'distance'],
    'p': [1, 2]
}
best_model, best_params, best_score, all_results = hyperparameter_search(model, X_train, y_train, X_val, y_val,
                                                                         param_grid, search_type='grid', cv=3,
                                                                         verbose=2)

y_pred = best_model.predict(X_test)
print(f"Test accuracy of best model: {accuracy_score(y_test, y_pred):.2f}")

plot_hyperparameter_search_results(all_results, param_grid)
plt.show()

cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
plt.show()
Fitting 3 folds for each of 20 candidates, totalling 60 fits
[CV] END ................n_neighbors=1, p=1, weights=uniform; total time=   1.2s
[CV] END ...............n_neighbors=1, p=1, weights=distance; total time=   1.2s
[CV] END ................n_neighbors=1, p=1, weights=uniform; total time=   1.2s
[CV] END ................n_neighbors=1, p=1, weights=uniform; total time=   1.2s
[CV] END ................n_neighbors=1, p=2, weights=uniform; total time=   0.2s
[CV] END ................n_neighbors=1, p=2, weights=uniform; total time=   0.1s
[CV] END ...............n_neighbors=1, p=2, weights=distance; total time=   0.1s
[CV] END ................n_neighbors=1, p=2, weights=uniform; total time=   0.2s
[CV] END ...............n_neighbors=1, p=2, weights=distance; total time=   0.1s
[CV] END ...............n_neighbors=1, p=2, weights=distance; total time=   0.1s
[CV] END ...............n_neighbors=1, p=1, weights=distance; total time=   1.6s
[CV] END ...............n_neighbors=1, p=1, weights=distance; total time=   1.6s
[CV] END ................n_neighbors=3, p=1, weights=uniform; total time=   1.4s
[CV] END ................n_neighbors=3, p=1, weights=uniform; total time=   1.3s
[CV] END ................n_neighbors=3, p=1, weights=uniform; total time=   1.2s
[CV] END ...............n_neighbors=3, p=1, weights=distance; total time=   1.2s
[CV] END ................n_neighbors=3, p=2, weights=uniform; total time=   0.2s
[CV] END ................n_neighbors=3, p=2, weights=uniform; total time=   0.2s
[CV] END ...............n_neighbors=3, p=2, weights=distance; total time=   0.2s
[CV] END ................n_neighbors=3, p=2, weights=uniform; total time=   0.2s
[CV] END ...............n_neighbors=3, p=1, weights=distance; total time=   1.9s
[CV] END ...............n_neighbors=3, p=2, weights=distance; total time=   0.2s
[CV] END ...............n_neighbors=3, p=1, weights=distance; total time=   1.9s
[CV] END ...............n_neighbors=3, p=2, weights=distance; total time=   0.2s
[CV] END ................n_neighbors=5, p=1, weights=uniform; total time=   1.5s
[CV] END ................n_neighbors=5, p=1, weights=uniform; total time=   1.5s
[CV] END ...............n_neighbors=5, p=1, weights=distance; total time=   1.5s
[CV] END ................n_neighbors=5, p=1, weights=uniform; total time=   1.5s
[CV] END ................n_neighbors=5, p=2, weights=uniform; total time=   0.1s
[CV] END ................n_neighbors=5, p=2, weights=uniform; total time=   0.2s
[CV] END ...............n_neighbors=5, p=2, weights=distance; total time=   0.2s
[CV] END ................n_neighbors=5, p=2, weights=uniform; total time=   0.2s
[CV] END ...............n_neighbors=5, p=2, weights=distance; total time=   0.2s
[CV] END ...............n_neighbors=5, p=2, weights=distance; total time=   0.2s
[CV] END ...............n_neighbors=5, p=1, weights=distance; total time=   1.8s
[CV] END ...............n_neighbors=5, p=1, weights=distance; total time=   1.8s
[CV] END ...............n_neighbors=10, p=1, weights=uniform; total time=   1.6s
[CV] END ...............n_neighbors=10, p=1, weights=uniform; total time=   1.5s
[CV] END ..............n_neighbors=10, p=1, weights=distance; total time=   1.1s
[CV] END ...............n_neighbors=10, p=1, weights=uniform; total time=   1.1s
[CV] END ...............n_neighbors=10, p=2, weights=uniform; total time=   0.1s
[CV] END ...............n_neighbors=10, p=2, weights=uniform; total time=   0.1s
[CV] END ..............n_neighbors=10, p=2, weights=distance; total time=   0.1s
[CV] END ...............n_neighbors=10, p=2, weights=uniform; total time=   0.2s
[CV] END ..............n_neighbors=10, p=2, weights=distance; total time=   0.1s
[CV] END ..............n_neighbors=10, p=2, weights=distance; total time=   0.1s
[CV] END ..............n_neighbors=10, p=1, weights=distance; total time=   1.2s
[CV] END ..............n_neighbors=10, p=1, weights=distance; total time=   1.2s
[CV] END ...............n_neighbors=15, p=1, weights=uniform; total time=   1.5s
[CV] END ...............n_neighbors=15, p=1, weights=uniform; total time=   1.5s
[CV] END ...............n_neighbors=15, p=1, weights=uniform; total time=   1.6s
[CV] END ..............n_neighbors=15, p=1, weights=distance; total time=   1.6s
[CV] END ...............n_neighbors=15, p=2, weights=uniform; total time=   0.1s
[CV] END ...............n_neighbors=15, p=2, weights=uniform; total time=   0.1s
[CV] END ..............n_neighbors=15, p=2, weights=distance; total time=   0.2s
[CV] END ...............n_neighbors=15, p=2, weights=uniform; total time=   0.2s
[CV] END ..............n_neighbors=15, p=2, weights=distance; total time=   0.1s
[CV] END ..............n_neighbors=15, p=2, weights=distance; total time=   0.1s
[CV] END ..............n_neighbors=15, p=1, weights=distance; total time=   1.4s
[CV] END ..............n_neighbors=15, p=1, weights=distance; total time=   1.5s
Validation Score with best hyperparameters: 0.6460807600950119
Test accuracy of best model: 0.63
param_n_neighbors
1     0.558192
3     0.582242
5     0.592410
10    0.593970
15    0.611267
Name: mean_test_score, dtype: float64
param_weights
distance    0.587468
uniform     0.587765
Name: mean_test_score, dtype: float64
param_p
1    0.588537
2    0.586696
Name: mean_test_score, dtype: float64
[figure: mean CV accuracy across the (n_neighbors, weights, p) grid]
[figure: confusion matrix of the best KNN model on the test set]
In [ ]:
from sklearn.svm import SVC

model = SVC()
param_grid = {
    'C': [0.001, 0.005, 0.01, 0.1],
    'gamma': ['scale', 'auto'],
    "kernel": ['rbf', 'poly', 'sigmoid', 'linear'],
}
best_model, best_params, best_score, all_results = hyperparameter_search(model, X_train, y_train, X_val, y_val,
                                                                         param_grid, search_type='grid', cv=3,
                                                                         verbose=2)

y_pred = best_model.predict(X_test)
print(f"Test accuracy of best model: {accuracy_score(y_test, y_pred):.2f}")

plot_hyperparameter_search_results(all_results, param_grid)
plt.show()

cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
plt.show()
Fitting 3 folds for each of 32 candidates, totalling 96 fits
[CV] END ..................C=0.001, gamma=scale, kernel=poly; total time=   4.5s
[CV] END ...................C=0.001, gamma=scale, kernel=rbf; total time=   4.5s
[CV] END ...................C=0.001, gamma=scale, kernel=rbf; total time=   4.5s
[CV] END ...................C=0.001, gamma=scale, kernel=rbf; total time=   4.5s
[CV] END ..................C=0.001, gamma=scale, kernel=poly; total time=   4.5s
[CV] END ...............C=0.001, gamma=scale, kernel=sigmoid; total time=   4.3s
[CV] END ...............C=0.001, gamma=scale, kernel=sigmoid; total time=   4.4s
[CV] END ..................C=0.001, gamma=scale, kernel=poly; total time=   4.6s
[CV] END ...............C=0.001, gamma=scale, kernel=sigmoid; total time=   4.1s
[CV] END ................C=0.001, gamma=scale, kernel=linear; total time=   4.3s
[CV] END ................C=0.001, gamma=scale, kernel=linear; total time=   4.3s
[CV] END ................C=0.001, gamma=scale, kernel=linear; total time=   4.3s
[CV] END ....................C=0.001, gamma=auto, kernel=rbf; total time=   4.4s
[CV] END ...................C=0.001, gamma=auto, kernel=poly; total time=   4.1s
[CV] END ....................C=0.001, gamma=auto, kernel=rbf; total time=   4.2s
[CV] END ....................C=0.001, gamma=auto, kernel=rbf; total time=   4.3s
[CV] END ...................C=0.001, gamma=auto, kernel=poly; total time=   4.4s
[CV] END ...................C=0.001, gamma=auto, kernel=poly; total time=   4.4s
[CV] END ................C=0.001, gamma=auto, kernel=sigmoid; total time=   4.2s
[CV] END ................C=0.001, gamma=auto, kernel=sigmoid; total time=   4.3s
[CV] END ................C=0.001, gamma=auto, kernel=sigmoid; total time=   4.1s
[CV] END .................C=0.001, gamma=auto, kernel=linear; total time=   4.1s
[CV] END .................C=0.001, gamma=auto, kernel=linear; total time=   4.2s
[CV] END .................C=0.001, gamma=auto, kernel=linear; total time=   4.2s
[CV] END ..................C=0.005, gamma=scale, kernel=poly; total time=   4.7s
[CV] END ...................C=0.005, gamma=scale, kernel=rbf; total time=   5.3s
[CV] END ...................C=0.005, gamma=scale, kernel=rbf; total time=   5.3s
[CV] END ...................C=0.005, gamma=scale, kernel=rbf; total time=   5.2s
[CV] END ..................C=0.005, gamma=scale, kernel=poly; total time=   4.9s
[CV] END ...............C=0.005, gamma=scale, kernel=sigmoid; total time=   4.6s
[CV] END ...............C=0.005, gamma=scale, kernel=sigmoid; total time=   4.7s
[CV] END ..................C=0.005, gamma=scale, kernel=poly; total time=   5.2s
[CV] END ...............C=0.005, gamma=scale, kernel=sigmoid; total time=   4.5s
[CV] END ................C=0.005, gamma=scale, kernel=linear; total time=   4.4s
[CV] END ................C=0.005, gamma=scale, kernel=linear; total time=   4.3s
[CV] END ................C=0.005, gamma=scale, kernel=linear; total time=   4.2s
[CV] END ...................C=0.005, gamma=auto, kernel=poly; total time=   4.2s
[CV] END ....................C=0.005, gamma=auto, kernel=rbf; total time=   5.1s
[CV] END ....................C=0.005, gamma=auto, kernel=rbf; total time=   4.9s
[CV] END ....................C=0.005, gamma=auto, kernel=rbf; total time=   5.0s
[CV] END ...................C=0.005, gamma=auto, kernel=poly; total time=   4.6s
[CV] END ...................C=0.005, gamma=auto, kernel=poly; total time=   5.0s
[CV] END ................C=0.005, gamma=auto, kernel=sigmoid; total time=   4.6s
[CV] END ................C=0.005, gamma=auto, kernel=sigmoid; total time=   4.6s
[CV] END ................C=0.005, gamma=auto, kernel=sigmoid; total time=   4.8s
[CV] END .................C=0.005, gamma=auto, kernel=linear; total time=   4.6s
[CV] END .................C=0.005, gamma=auto, kernel=linear; total time=   4.6s
[CV] END .................C=0.005, gamma=auto, kernel=linear; total time=   4.4s
[CV] END ...................C=0.01, gamma=scale, kernel=poly; total time=   4.6s
[CV] END ....................C=0.01, gamma=scale, kernel=rbf; total time=   5.6s
[CV] END ....................C=0.01, gamma=scale, kernel=rbf; total time=   5.3s
[CV] END ....................C=0.01, gamma=scale, kernel=rbf; total time=   5.3s
[CV] END ...................C=0.01, gamma=scale, kernel=poly; total time=   4.6s
[CV] END ...................C=0.01, gamma=scale, kernel=poly; total time=   4.8s
[CV] END ................C=0.01, gamma=scale, kernel=sigmoid; total time=   4.2s
[CV] END ................C=0.01, gamma=scale, kernel=sigmoid; total time=   4.3s
[CV] END ................C=0.01, gamma=scale, kernel=sigmoid; total time=   4.8s
[CV] END .................C=0.01, gamma=scale, kernel=linear; total time=   4.7s
[CV] END .................C=0.01, gamma=scale, kernel=linear; total time=   4.5s
[CV] END .................C=0.01, gamma=scale, kernel=linear; total time=   4.4s
[CV] END ....................C=0.01, gamma=auto, kernel=poly; total time=   4.7s
[CV] END .....................C=0.01, gamma=auto, kernel=rbf; total time=   5.6s
[CV] END .....................C=0.01, gamma=auto, kernel=rbf; total time=   5.2s
[CV] END .....................C=0.01, gamma=auto, kernel=rbf; total time=   5.3s
[CV] END ....................C=0.01, gamma=auto, kernel=poly; total time=   4.6s
[CV] END ....................C=0.01, gamma=auto, kernel=poly; total time=   5.1s
[CV] END .................C=0.01, gamma=auto, kernel=sigmoid; total time=   5.0s
[CV] END .................C=0.01, gamma=auto, kernel=sigmoid; total time=   4.9s
[CV] END .................C=0.01, gamma=auto, kernel=sigmoid; total time=   5.0s
[CV] END ..................C=0.01, gamma=auto, kernel=linear; total time=   4.5s
[CV] END ..................C=0.01, gamma=auto, kernel=linear; total time=   4.5s
[CV] END ..................C=0.01, gamma=auto, kernel=linear; total time=   4.5s
[CV] END ....................C=0.1, gamma=scale, kernel=poly; total time=   4.6s
[CV] END .....................C=0.1, gamma=scale, kernel=rbf; total time=   5.7s
[CV] END .....................C=0.1, gamma=scale, kernel=rbf; total time=   5.5s
[CV] END .....................C=0.1, gamma=scale, kernel=rbf; total time=   5.2s
[CV] END ....................C=0.1, gamma=scale, kernel=poly; total time=   4.7s
[CV] END ....................C=0.1, gamma=scale, kernel=poly; total time=   4.9s
[CV] END .................C=0.1, gamma=scale, kernel=sigmoid; total time=   4.5s
[CV] END .................C=0.1, gamma=scale, kernel=sigmoid; total time=   4.5s
[CV] END .................C=0.1, gamma=scale, kernel=sigmoid; total time=   4.5s
[CV] END ..................C=0.1, gamma=scale, kernel=linear; total time=   5.9s
[CV] END ..................C=0.1, gamma=scale, kernel=linear; total time=   6.2s
[CV] END ..................C=0.1, gamma=scale, kernel=linear; total time=   6.0s
[CV] END ......................C=0.1, gamma=auto, kernel=rbf; total time=   4.8s
[CV] END .....................C=0.1, gamma=auto, kernel=poly; total time=   4.4s
[CV] END ......................C=0.1, gamma=auto, kernel=rbf; total time=   5.0s
[CV] END ......................C=0.1, gamma=auto, kernel=rbf; total time=   5.1s
[CV] END .....................C=0.1, gamma=auto, kernel=poly; total time=   4.4s
[CV] END .....................C=0.1, gamma=auto, kernel=poly; total time=   4.9s
[CV] END ..................C=0.1, gamma=auto, kernel=sigmoid; total time=   4.4s
[CV] END ..................C=0.1, gamma=auto, kernel=sigmoid; total time=   4.6s
[CV] END ..................C=0.1, gamma=auto, kernel=sigmoid; total time=   4.2s
[CV] END ...................C=0.1, gamma=auto, kernel=linear; total time=   6.0s
[CV] END ...................C=0.1, gamma=auto, kernel=linear; total time=   5.4s
[CV] END ...................C=0.1, gamma=auto, kernel=linear; total time=   5.1s
Validation Score with best hyperparameters: 0.6484560570071259
Test accuracy of best model: 0.68
param_C
0.001    0.636950
0.005    0.631903
0.010    0.629083
0.100    0.621734
Name: mean_test_score, dtype: float64
param_gamma
auto     0.629918
scale    0.629918
Name: mean_test_score, dtype: float64
param_kernel
linear     0.607779
poly       0.636951
rbf        0.637470
sigmoid    0.637470
Name: mean_test_score, dtype: float64
[figure: mean CV accuracy across the (C, gamma, kernel) grid]
[figure: confusion matrix of the best SVC model on the test set]

scGPT embeddings¶

scGPT is a foundation model for single-cell genomics created by Cui et al. ("scGPT: Towards Building a Foundation Model for Single-Cell Multi-omics Using Generative AI") that can generate embeddings for both genes and cells. We use its cell embeddings as input features for a random forest classifier.

In [ ]:
# constant definition and utility functions

set_seed(42)
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
n_hvg = 1200
n_bins = 51
mask_value = -1
pad_value = -2
n_input_bins = n_bins

device = torch.device("cpu")


def preprocess_df(name, extension='csv'):
    """
    Preparing dataset for scGPT
    """
    csv_file = f'data/{name}.{extension}'
    df = pd.read_csv(csv_file)
    df = df.drop(columns=['is_true', "Variant_Classification", "mutation"])
    df = df.clip(lower=0)
    df.columns = [s.split("..")[0] for s in df.columns]
    df.to_csv(f'data/{name}_processed_gpt.{extension}', index=False)


def load_model(model_dir="model_params"):
    # the weights can be downloaded from https://github.com/bowang-lab/scGPT/tree/main?tab=readme-ov-file#pretrained-scGPT-checkpoints
    # we used the whole human model
    model_config_file = model_dir + "/args.json"
    model_file = model_dir + "/best_model.pt"
    vocab_file = model_dir + "/vocab.json"

    vocab = GeneVocab.from_file(vocab_file)
    for s in special_tokens:
        if s not in vocab:
            vocab.append_token(s)

    with open(model_config_file, "r") as f:
        model_configs = json.load(f)
    print(
        f"Resume model from {model_file}, the model args will override the "
        f"config {model_config_file}."
    )
    embsize = model_configs["embsize"]
    nhead = model_configs["nheads"]
    d_hid = model_configs["d_hid"]
    nlayers = model_configs["nlayers"]
    n_layers_cls = model_configs["n_layers_cls"]

    ntokens = len(vocab)  # size of vocabulary
    model = TransformerModel(
        ntokens,
        embsize,
        nhead,
        d_hid,
        nlayers,
        vocab=vocab,
        pad_value=pad_value,
        n_input_bins=n_input_bins,
    )

    try:
        model.load_state_dict(torch.load(model_file, map_location=device))
        print(f"Loading all model params from {model_file}")
    except Exception:
        # only load params that are in the model and match the expected shape
        model_dict = model.state_dict()
        pretrained_dict = torch.load(model_file, map_location=device)
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
        for k, v in pretrained_dict.items():
            print(f"Loading params {k} with shape {v.shape}")
        # update and load once, after collecting all matching params
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    model.to(device)
    gene2idx = vocab.get_stoi()

    return model, gene2idx, vocab_file


def load_and_preprocess(name, extension='csv'):
    adata = sc.read(f"data/{name}_processed_gpt.{extension}", cache=False)
    ori_batch_col = "batch"
    data_is_raw = True

    # Preprocess the data following the scGPT data pre-processing pipeline
    preprocessor = Preprocessor(
        use_key="X",  # the key in adata.layers to use as raw data
        filter_gene_by_counts=3,  # step 1
        filter_cell_by_counts=False,  # step 2
        normalize_total=1e4,  # 3. whether to normalize the raw data and to what sum
        result_normed_key="X_normed",  # the key in adata.layers to store the normalized data
        log1p=data_is_raw,  # 4. whether to log1p the normalized data
        result_log1p_key="X_log1p",
        subset_hvg=n_hvg,  # 5. whether to subset the raw data to highly variable genes
        hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
        binning=n_bins,  # 6. whether to bin the raw data and to what number of bins
        result_binned_key="X_binned",  # the key in adata.layers to store the binned data
    )
    preprocessor(adata)
    return adata

def compute_gene_embeddings(gene2idx, model_gpt, adata):
    """
    Compute embedding of each gene in adata
    """
    gene_ids = np.array(list(gene2idx.values()))
    gene_embeddings = model_gpt.encoder(torch.tensor(gene_ids, dtype=torch.long).to(device))
    gene_embeddings = gene_embeddings.detach().cpu().numpy()

    # Keep only the intersection between the HVGs selected during preprocessing
    # and scGPT's 30+K-gene vocabulary
    hvg_set = set(adata.var.index)
    gene_embeddings = {gene: gene_embeddings[i] for i, gene in enumerate(gene2idx.keys())
                       if gene in hvg_set}
    print('Retrieved gene embeddings for {} genes.'.format(len(gene_embeddings)))
    return gene_embeddings




def get_cell_embeddings(adata, gene_embeddings):
    """
    Compute cell embeddings of adata
    """
    cell_embeddings_l = []
    for cell_idx in tqdm.tqdm(range(adata.shape[0])):
        cell_expression = adata[cell_idx].X.toarray().flatten()
        cell_embedding = np.zeros_like(next(iter(gene_embeddings.values())))

        for gene_idx, expression_level in enumerate(cell_expression):
            gene_name = adata.var.index[gene_idx]
            if gene_name in gene_embeddings:
                cell_embedding += gene_embeddings[gene_name] * expression_level

        cell_embeddings_l.append(cell_embedding)

    cell_embeddings = np.array(cell_embeddings_l)
    print('Computed embeddings for {} cells.'.format(cell_embeddings.shape[0]))
    return cell_embeddings
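The per-cell Python loop above visits every gene of every cell from interpreted code and becomes slow for thousands of cells. The same expression-weighted sum of gene embeddings can be written as a single matrix product; below is a minimal vectorized sketch under the same assumptions as above (the helper name get_cell_embeddings_fast is ours):

def get_cell_embeddings_fast(adata, gene_embeddings):
    # Genes that survive the vocabulary/HVG intersection, in adata's column order
    genes = [g for g in adata.var.index if g in gene_embeddings]
    emb_matrix = np.stack([gene_embeddings[g] for g in genes])  # (n_genes, emb_dim)
    expr = adata[:, genes].X
    if hasattr(expr, "toarray"):  # densify if the expression matrix is sparse
        expr = expr.toarray()
    # One matrix product replaces the per-cell, per-gene Python loop
    return np.asarray(expr) @ emb_matrix  # (n_cells, emb_dim)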

We preprocess the data, load the model and compute the embeddings.

In [ ]:
# Load the pre-trained scGPT model and preprocess the data
model_gpt, gene2idx, vocab_file = load_model()
preprocess_df("CCLE_labels")
preprocess_df("TCGA_labels")
Resume model from model_params/best_model.pt, the model args will override the config model_params/args.json.
Loading params encoder.embedding.weight with shape torch.Size([60697, 512])
Loading params encoder.enc_norm.weight with shape torch.Size([512])
Loading params encoder.enc_norm.bias with shape torch.Size([512])
Loading params value_encoder.linear1.weight with shape torch.Size([512, 1])
Loading params value_encoder.linear1.bias with shape torch.Size([512])
Loading params value_encoder.linear2.weight with shape torch.Size([512, 512])
Loading params value_encoder.linear2.bias with shape torch.Size([512])
Loading params value_encoder.norm.weight with shape torch.Size([512])
Loading params value_encoder.norm.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.0.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.0.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.0.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.0.linear1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.0.linear2.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.0.linear2.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.0.norm1.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.0.norm1.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.0.norm2.weight with shape torch.Size([512])
Loading params transformer_encoder.layers.0.norm2.bias with shape torch.Size([512])
[analogous param-loading lines for transformer_encoder.layers.1 through transformer_encoder.layers.11 omitted; every layer has the same parameter shapes as layers.0]
Loading params decoder.fc.0.weight with shape torch.Size([512, 512])
Loading params decoder.fc.0.bias with shape torch.Size([512])
Loading params decoder.fc.2.weight with shape torch.Size([512, 512])
Loading params decoder.fc.2.bias with shape torch.Size([512])
Loading params decoder.fc.4.weight with shape torch.Size([1, 512])
Loading params decoder.fc.4.bias with shape torch.Size([1])
In [ ]:
adataCCLE = load_and_preprocess("CCLE_labels")
adataTCGA = load_and_preprocess("TCGA_labels")

gene_embeddings_ccle = compute_gene_embeddings(gene2idx, model_gpt, adataCCLE)
gene_embeddings_tcga = compute_gene_embeddings(gene2idx, model_gpt, adataTCGA)

cell_embeddings_ccle = get_cell_embeddings(adataCCLE, gene_embeddings_ccle)
cell_embeddings_tcga = get_cell_embeddings(adataTCGA, gene_embeddings_tcga)

X_train_embedded, X_test_embedded, y_train, y_test = train_test_split(cell_embeddings_tcga, y, test_size=0.2, random_state=0)
X_val_embedded, X_test_embedded, y_val, y_test = train_test_split(X_test_embedded, y_test, test_size=0.5, random_state=0)
scGPT - INFO - Filtering genes by counts ...
scGPT - INFO - Normalizing total counts ...
scGPT - INFO - Log1p transforming ...
scGPT - INFO - Subsetting highly variable genes ...
scGPT - WARNING - No batch_key is provided, will use all cells for HVG selection.
scGPT - INFO - Binning data ...
scGPT - INFO - Filtering genes by counts ...
scGPT - INFO - Normalizing total counts ...
scGPT - INFO - Log1p transforming ...
scGPT - INFO - Subsetting highly variable genes ...
scGPT - WARNING - No batch_key is provided, will use all cells for HVG selection.
scGPT - INFO - Binning data ...
Retrieved gene embeddings for 579 genes.
Retrieved gene embeddings for 545 genes.
100%|██████████| 924/924 [00:02<00:00, 443.69it/s]
Computed embeddings for 924 cells.
100%|██████████| 4211/4211 [00:08<00:00, 494.44it/s]
Computed embeddings for 4211 cells.

In [ ]:
model = RandomForestClassifier()
param_grid = {
    'n_estimators': [100, 200],
    'max_depth': [3, 5, 7, 10, 15, 20],
    'bootstrap': [True, False],
}
best_model, best_params, best_score, all_results = hyperparameter_search(model, X_train_embedded, y_train,
                                                                         X_val_embedded, y_val,
                                                                         param_grid, search_type='grid', cv=3,
                                                                         verbose=2)

y_pred = best_model.predict(X_test_embedded)
print(f"Test accuracy of best model: {accuracy_score(y_test, y_pred):.2f}")

plot_hyperparameter_search_results(all_results, param_grid)
plt.show()

cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
plt.show()
Fitting 3 folds for each of 24 candidates, totalling 72 fits
Validation Score with best hyperparameters: 0.6484560570071259
Test accuracy of best model: 0.68
param_n_estimators
100    0.623564
200    0.624950
Name: mean_test_score, dtype: float64
param_max_depth
3     0.636728
5     0.634204
7     0.627673
10    0.622624
15    0.614681
20    0.609634
Name: mean_test_score, dtype: float64
param_bootstrap
False    0.624233
True     0.624282
Name: mean_test_score, dtype: float64
[figure: mean CV accuracy across the (n_estimators, max_depth, bootstrap) grid]
[figure: confusion matrix of the best random forest on the test set]

As we can see, the results are completely in line with the previous ones: the accuracy of the best model is still only slightly above that of always guessing the most represented class.

Feature Selection using correlations¶

We import the previously computed list of features, sorted by correlation with the labels.

In [ ]:
import json
with open('good_genes_missense.txt') as f:
    ordered_genes = json.load(f)
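For reference, a ranking like the one stored in good_genes_missense.txt can be produced by correlating each gene column with the binary label. The sketch below is our hypothetical reconstruction (the actual upstream computation is not shown in this notebook and may differ):

# Rank genes by absolute Pearson correlation between expression and label;
# constant columns yield NaN and are sorted last.
correlations = X_train.apply(lambda col: np.corrcoef(col, y_train)[0, 1])
ordered_genes_sketch = correlations.abs().sort_values(ascending=False).index.tolist()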
In [ ]:
N_values = [10, 25, 50, 75, 100, 125, 150, 200, 300]
val_scores = []
train_scores = []
for N in N_values:
    good_genes = ordered_genes[:N]
    X_train_good = X_train[good_genes]
    X_val_good = X_val[good_genes]
    model = RandomForestClassifier(n_estimators=200, max_depth=10, bootstrap=True)
    model.fit(X_train_good, y_train)
    val_scores.append(model.score(X_val_good, y_val))
    train_scores.append(model.score(X_train_good, y_train))
In [ ]:
plt.plot(N_values, val_scores, label='Validation')
plt.plot(N_values, train_scores, label='Train')
plt.xlabel('Number of genes')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Accuracy vs number of genes selected')
plt.show()
[figure: train and validation accuracy vs number of genes selected]
In [ ]:
N = 75
good_genes = ordered_genes[:N]
X_train_good = X_train[good_genes]
X_val_good = X_val[good_genes]
model = RandomForestClassifier(n_estimators=200, max_depth=10, bootstrap=True)
model.fit(X_train_good, y_train)
X_test_good = X_test[good_genes]
model.score(X_test_good, y_test)
0.6469194312796208
In [ ]:
y_pred = model.predict(X_test_good)
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
plt.show()
[figure: confusion matrix on the test set]

Feature Extraction using PCA¶

In [ ]:
N_values = [10, 25, 50, 75, 100, 125, 150, 200, 300]
val_scores = []
train_scores = []
for N in N_values:
    pca = PCA(n_components=N)
    pca.fit(X_train)
    X_train_pca = pca.transform(X_train)
    X_val_pca = pca.transform(X_val)
    model = RandomForestClassifier(n_estimators=200, max_depth=10, bootstrap=True)
    model.fit(X_train_pca, y_train)
    val_scores.append(model.score(X_val_pca, y_val))
    train_scores.append(model.score(X_train_pca, y_train))
In [ ]:
plt.plot(N_values, val_scores, label='Validation')
plt.plot(N_values, train_scores, label='Train')
plt.xlabel('Number of PCA components')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Accuracy vs number of PCA components')
plt.show()
[figure: train and validation accuracy vs number of PCA components]
In [ ]:
N = 25
pca = PCA(n_components=N)
pca.fit(X_train)
X_train_pca = pca.transform(X_train)
X_val_pca = pca.transform(X_val)
model = RandomForestClassifier(n_estimators=200, max_depth=10, bootstrap=True)
model.fit(X_train_pca, y_train)
X_test_pca = pca.transform(X_test)
model.score(X_test_pca, y_test)
0.6658767772511849
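As a sanity check on the choice N = 25, we can inspect how much of the training-set variance the retained components capture, using the pca object fitted above:

# Cumulative explained variance of the retained principal components
print(f"Explained variance with {N} components: "
      f"{pca.explained_variance_ratio_.sum():.2%}")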
In [ ]:
y_pred = model.predict(X_test_pca)
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
plt.show()
[figure: confusion matrix on the test set]

Class weights¶

In [ ]:
class_weights = {
    1: len(y_train) / (y_train == 1).sum(),
    0: len(y_train) / (y_train == 0).sum()
}

N = 75
good_genes = ordered_genes[:N]
X_train_good = X_train[good_genes]
X_val_good = X_val[good_genes]
model = RandomForestClassifier(n_estimators=200, max_depth=5, bootstrap=True, class_weight=class_weights)
model.fit(X_train_good, y_train)
X_test_good = X_test[good_genes]
model.score(X_test_good, y_test)
0.514218009478673
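These weights are, up to a constant factor equal to the number of classes, the same as scikit-learn's built-in 'balanced' heuristic (w_c = n_samples / (n_classes * n_c)); since a uniform rescaling of all class weights should not change the fitted forest, the model above can equivalently be written as:

# Equivalent model using the built-in heuristic instead of the manual dict
model = RandomForestClassifier(n_estimators=200, max_depth=5, bootstrap=True,
                               class_weight='balanced')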
In [ ]:
y_pred = model.predict(X_test_good)
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
plt.show()
[figure: confusion matrix on the test set]

By weighting the loss terms of the different classes inversely to the class proportions, we obtain more balanced predictions. However, the model is still wrong about 50% of the time on each class.

Conclusions¶

From our analysis, it emerged that the data is too noisy to solve either task with accuracy meaningfully better than that obtained by always outputting the most frequent label. This is in line with what was found in the paper. Our models learn, but they overfit the training set and generalize poorly, even though we explicitly regularized them through the hyperparameter choices in our grid searches.
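As a concrete reference for the most-frequent-label baseline mentioned above, it can be measured directly with scikit-learn's DummyClassifier; a minimal sketch reusing the splits defined earlier:

# Majority-class baseline: always predict the modal class of the training set
from sklearn.dummy import DummyClassifier

baseline = DummyClassifier(strategy='most_frequent').fit(X_train, y_train)
print(f"Majority-class baseline accuracy: {baseline.score(X_test, y_test):.2f}")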

Furthermore, our attempts at feature selection and extraction through correlation analysis and PCA did not improve performance. We expected the regularizing effect of these methods, which remove some of the least informative features and should therefore reduce noise and mitigate overfitting, to help with our tasks. We conclude that there is no signal in the data strong enough to solve our tasks with meaningful accuracy. Indeed, inspection of the confusion matrices shows that many of our models end up simply outputting the most frequent class every time.

Even when the loss is weighted inversely to the class frequencies during training, the predictions become more balanced but the performance remains poor, with each class mispredicted about 50% of the time.

Appendix¶

We show here some results for the other task and dataset combinations. As anticipated, the conclusions are essentially the same as those presented for the classification of mutation type on the TCGA dataset. To avoid repetition and keep this notebook light, we do not include all of our explorations and every model we trained.

CCLE Dataset¶

In [42]:
csv_file = 'data/CCLE_labels.csv'
df = pd.read_csv(csv_file)
df_full = df
df.head()
Out[42]:
Variant_Classification MAD1L1..ENSG00000002822. ITGA3..ENSG00000005884. MYH13..ENSG00000006788. GAS7..ENSG00000007237. REV3L..ENSG00000009413. TSPAN9..ENSG00000011105. RNF216..ENSG00000011275. CEP68..ENSG00000011523. BRCA1..ENSG00000012048. ... BGLAP..ENSG00000242252. MICAL3..ENSG00000243156. FMN1..ENSG00000248905. GATC..ENSG00000257218. CUX1..ENSG00000257923. BAHCC1..ENSG00000266074. PRAG1..ENSG00000275342. UHRF1..ENSG00000276043. is_true mutation
0 1__639V_URINARY_TRACT_Missense_Mutation_c.(742... 3102 8389 0 3 3104 1698 5130 1687 4486 ... 9 8144 3539 2944 6288 3160 140 23040 False Missense_Mutation
1 1__BL41_HAEMATOPOIETIC_AND_LYMPHOID_TISSUE_Mis... 5645 312 0 14 4925 35 4856 1110 10004 ... 33 8346 40 4463 23703 133 4115 38105 False Missense_Mutation
2 1__CA46_HAEMATOPOIETIC_AND_LYMPHOID_TISSUE_Mis... 6967 2113 0 113 8180 115 3648 2871 7615 ... 34 6574 234 4132 18149 265 1580 4790 False Missense_Mutation
3 1__CAL29_URINARY_TRACT_Missense_Mutation_c.(84... 1882 24720 0 41 1809 1731 4349 759 2988 ... 9 4380 1202 2996 19658 924 3914 10313 False Missense_Mutation
4 1__CI1_HAEMATOPOIETIC_AND_LYMPHOID_TISSUE_Miss... 3139 1444 0 619 4744 25 6039 1757 12484 ... 14 6447 44 3915 4945 25 4066 20850 False Missense_Mutation

5 rows × 582 columns

Mutation type¶

In [24]:
df = deepcopy(df_full)
df = log_and_normalize(df)
y = df['mutation']
df = df.drop(columns=["is_true", "mutation", "Variant_Classification"])
y = y.apply(lambda x: 1 if x == 'Missense_Mutation' else 0)
In [25]:
(y == 1).sum() / len(y)
Out[25]:
0.6536796536796536
In [26]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(df, y, test_size=0.2, random_state=0)
X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=0.5, random_state=0)
In [27]:
from sklearn.ensemble import RandomForestClassifier


model = RandomForestClassifier()
param_grid = {
    'n_estimators': [100, 200],
    'max_depth': [3, 5, 7, 10, 15, 20],
}
best_model, best_params, best_score, all_results = hyperparameter_search(model, X_train, y_train, X_val, y_val,
                                                                         param_grid, search_type='grid', cv=3,
                                                                         verbose=1)
Fitting 3 folds for each of 12 candidates, totalling 36 fits
Validation Score with best hyperparameters: 0.6847826086956522
In [28]:
best_score, best_params
Out[28]:
(0.6495232766092843, {'max_depth': 3, 'n_estimators': 200})
In [29]:
from sklearn.metrics import accuracy_score

y_pred = best_model.predict(X_test)
print(f"Test accuracy of best model: {accuracy_score(y_test, y_pred):.2f}")
Test accuracy of best model: 0.66
In [30]:
from sklearn.metrics import confusion_matrix
import seaborn as sns

cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
plt.show()
[figure: confusion matrix on the test set]

Functional vs Dysfunctional¶

In [48]:
df = deepcopy(df_full)
df = log_and_normalize(df)
y = df['is_true']
df = df.drop(columns=["is_true", "mutation", "Variant_Classification"])
y = y.astype(int)
In [49]:
(y == 1).sum() / len(y)
Out[49]:
0.32575757575757575
In [50]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(df, y, test_size=0.2, random_state=0)
X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=0.5, random_state=0)
In [51]:
from sklearn.ensemble import RandomForestClassifier


model = RandomForestClassifier()
param_grid = {
    'n_estimators': [100, 200],
    'max_depth': [3, 5, 7, 10, 15, 20],
}
best_model, best_params, best_score, all_results = hyperparameter_search(model, X_train, y_train, X_val, y_val,
                                                                         param_grid, search_type='grid', cv=3,
                                                                         verbose=1)
Fitting 3 folds for each of 12 candidates, totalling 36 fits
Validation Score with best hyperparameters: 0.7282608695652174
In [52]:
best_score, best_params
Out[52]:
(0.669821050437225, {'max_depth': 3, 'n_estimators': 100})
In [53]:
from sklearn.metrics import accuracy_score

y_pred = best_model.predict(X_test)
print(f"Test accuracy of best model: {accuracy_score(y_test, y_pred):.2f}")
Test accuracy of best model: 0.68
In [54]:
from sklearn.metrics import confusion_matrix
import seaborn as sns

cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
plt.show()
[figure: confusion matrix on the test set]

TCGA Dataset¶

In [55]:
csv_file = 'data/TCGA_labels.csv'
df = pd.read_csv(csv_file)
df_full = df
df.head()
Out[55]:
Variant_Classification ABCB9..ENSG00023457 ABLIM1..ENSG0003983 ACTA2..ENSG00059 ACTB..ENSG00060 ADORA2B..ENSG000136 ADRB2..ENSG000154 AEBP2..ENSG000121536 AEN..ENSG00064782 AGAP1..ENSG000116987 ... ZCCHC2..ENSG00054877 ZDHHC14..ENSG00079683 ZFP36L1..ENSG000677 ZMAT3..ENSG00064393 ZMIZ1..ENSG00057178 ZMIZ2..ENSG00083637 ZMYND8..ENSG00023613 ZNF561..ENSG00093134 is_true mutation
0 A129Vfs*20_TCGA-66-2785_Frame_Shift_Ins_17:g.7... 376.831000 1358.86000 2471.580000 143602.00000 159.674000 63.136500 946.639000 626.477000 344.195000 ... 323.344000 75.356400 8558.040000 43.991900 1783.300000 5320.570000 1018.330000 821.181000 True Frame_Shift_Ins
1 A138_P142del_TCGA-25-2393_In_Frame_Del_17:g.75... 198.244448 5367.62179 2528.570328 77726.97678 19.656121 2.579692 2130.976296 732.991931 386.605718 ... 228.638412 322.247574 6446.509718 36.542642 3207.438557 3213.116903 1688.261865 1149.407697 True In_Frame_Del
2 A138Cfs*27_TCGA-55-6980_Frame_Shift_Del_17:g.7... 117.516000 1936.34000 14533.700000 185841.00000 95.490700 191.866000 766.578000 256.410000 239.611000 ... 230.672000 121.132000 12726.800000 74.270600 2496.910000 4005.300000 923.961000 391.689000 True Frame_Shift_Del
3 A138Cfs*27_TCGA-55-6980_Frame_Shift_Del_17:g.7... 60.747000 5667.60000 3560.420000 107645.00000 86.834700 1047.620000 698.413000 186.741000 262.372000 ... 638.609000 343.604000 8024.280000 78.431400 3746.030000 2692.810000 1168.070000 670.402000 True Frame_Shift_Del
4 A138Cfs*27_TCGA-D8-A13Y_Frame_Shift_Del_17:g.7... 327.477000 1096.61000 3430.480000 64166.60000 51.837300 9.491300 706.010000 1617.540000 821.366000 ... 806.811000 124.118000 1350.690000 237.649000 1885.860000 2283.400000 1967.630000 480.043000 True Frame_Shift_Del

5 rows × 554 columns

Functional vs Dysfunctional¶

In [56]:
df = deepcopy(df_full)
df = log_and_normalize(df)
y = df['is_true']
df = df.drop(columns=["is_true", "mutation", "Variant_Classification"])
y = y.astype(int)
In [57]:
(y == 1).sum() / len(y)
Out[57]:
0.4490619805271907
In [58]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(df, y, test_size=0.2, random_state=0)
X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=0.5, random_state=0)
In [59]:
from sklearn.ensemble import RandomForestClassifier


model = RandomForestClassifier()
param_grid = {
    'n_estimators': [100, 200],
    'max_depth': [3, 5, 7, 10, 15, 20],
}
best_model, best_params, best_score, all_results = hyperparameter_search(model, X_train, y_train, X_val, y_val,
                                                                         param_grid, search_type='grid', cv=3,
                                                                         verbose=1)
Fitting 3 folds for each of 12 candidates, totalling 36 fits
Validation Score with best hyperparameters: 0.5629453681710214
In [60]:
best_score, best_params
Out[60]:
(0.5463172397591757, {'max_depth': 3, 'n_estimators': 200})
In [61]:
from sklearn.metrics import accuracy_score

y_pred = best_model.predict(X_test)
print(f"Test accuracy of best model: {accuracy_score(y_test, y_pred):.2f}")
Test accuracy of best model: 0.57
In [62]:
from sklearn.metrics import confusion_matrix
import seaborn as sns

cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
plt.show()
[figure: confusion matrix on the test set]